<!DOCTYPE html>
<html lang="en">
<head>
    <meta http-equiv="content-type" content="text/html;charset=utf-8"/>
    <meta name="viewport" content="width=device-width, initial-scale=1.0"/>
    <meta name="description" content="This is a PyTorch/Triton implementation of Flash Attention 2 with explanations."/>

    <meta name="twitter:card" content="summary"/>
    <meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
    <meta name="twitter:title" content="Flash Attention"/>
    <meta name="twitter:description" content="This is a PyTorch/Triton implementation of Flash Attention 2 with explanations."/>
    <meta name="twitter:site" content="@labmlai"/>
    <meta name="twitter:creator" content="@labmlai"/>

    <meta property="og:url" content="https://nn.labml.ai/transformers/flash/index.html"/>
    <meta property="og:title" content="Flash Attention"/>
    <meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
    <meta property="og:site_name" content="Flash Attention"/>
    <meta property="og:type" content="object"/>
    <meta property="og:description" content="This is a PyTorch/Triton implementation of Flash Attention 2 with explanations."/>

    <title>Flash Attention</title>
    <link rel="shortcut icon" href="/icon.png"/>
    <link rel="stylesheet" href="../../pylit.css?v=1">
    <link rel="canonical" href="https://nn.labml.ai/transformers/flash/index.html"/>
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">

</head>
<body>
<div id='container'>
    <div id="background"></div>
    <div class='section'>
        <div class='docs'>
            <p>
                <a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/flash/__init__.py" target="_blank">
                    View code on Github</a>
            </p>
        </div>
    </div>
    <div class='section' id='section-0'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-0'>#</a>
            </div>
            <h1>Flash Attention</h1>
<p>Flash attention speeds up the transformer attention mechanism by reducing the number of memory reads and writes between GPU high-bandwidth memory (HBM) and GPU on-chip SRAM.</p>
<p>It was introduced in the paper <a href="https://arxiv.org/abs/2205.14135">FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness</a> and further optimized in the paper <a href="https://arxiv.org/abs/2307.08691">FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning</a>. The official CUDA implementation can be found at <a href="https://github.com/Dao-AILab/flash-attention">Dao-AILab/flash-attention</a>.</p>
<p>Our implementation is based on <a href="https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html">Triton&#x27;s example implementation</a>.</p>
<p><em>Note: You can click on the mathematical symbols or identifiers to highlight them</em>.</p>
<p>You can run <a href="./test.html">test.py</a> to check the correctness and measure the performance of this implementation.</p>
<h2>Forward pass</h2>
<p>Here&#x27;s the attention forward pass. The formulas represent a single attention head. <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8777699999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> is the query vector (a row vector) at position <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.65952em;vertical-align:0em;"></span><span class="mord coloredeq eqcz" style=""><span class="mord mathnormal" style="">i</span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span 
class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqck" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> are the key and value row vectors at position <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.85396em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqda" style=""><span class="mord mathnormal" style="margin-right:0.05724em">j</span></span></span></span></span></span>. 
<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqci" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> is the output vector at position <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.65952em;vertical-align:0em;"></span><span class="mord coloredeq eqcz" style=""><span class="mord mathnormal" style="">i</span></span></span></span></span></span>.</p>
<span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:12.791673em;vertical-align:-6.1458365em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:6.6458365em;"><span style="top:-9.249606499999999em;"><span class="pstrut" style="height:3.518331em;"></span><span class="mord"><span class="mord coloredeq eqbu" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-7.5396015em;"><span class="pstrut" style="height:3.518331em;"></span><span class="mord"><span class="mord coloredeq eqch" style=""><span class="mord" style=""><span class="mord mathnormal" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span 
class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-4.3074935em;"><span class="pstrut" style="height:3.518331em;"></span><span class="mord"><span class="mord coloredeq eqbt" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-2.1214885em;"><span class="pstrut" style="height:3.518331em;"></span><span class="mord"><span class="mord coloredeq eqci" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" 
style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:0.9137284999999997em;"><span class="pstrut" style="height:3.518331em;"></span><span class="mord"></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:6.1458365em;"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:6.6458365em;"><span style="top:-9.249606499999999em;"><span class="pstrut" style="height:3.518331em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord coloredeq eqbv" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span></span><span class="mord coloredeq eqcj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqbx" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord 
mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.914561em;"><span style="top:-3.1362300000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.13889em">T</span></span></span></span></span></span></span></span></span></span></span><span style="top:-7.5396015em;"><span class="pstrut" style="height:3.518331em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.050005em;"><span style="top:-1.8723309999999997em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqda" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span><span style="top:-3.0500049999999996em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol 
large-op">∑</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.413777em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord coloredeq eqcy" style=""><span class="mord mathnormal" style="">e</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.891331em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqbu" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.05764em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-4.3074935em;"><span class="pstrut" style="height:3.518331em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span 
class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.518331em;"><span style="top:-2.3139999999999996em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqch" style=""><span class="mord" style=""><span class="mord mathnormal" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord coloredeq eqcy" style=""><span class="mord mathnormal" style="">e</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.841331em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqbu" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span 
style="top:-2.357em;margin-left:-0.05764em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.8360000000000001em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span><span style="top:-2.1214885em;"><span class="pstrut" style="height:3.518331em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.050005em;"><span style="top:-1.8723309999999997em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqda" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span><span style="top:-3.0500049999999996em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:1.413777em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbt" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqck" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span><span style="top:0.9137284999999997em;"><span class="pstrut" style="height:3.518331em;"></span><span class="mord"><span 
class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.32144em;"><span style="top:-2.3139999999999996em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqch" style=""><span class="mord" style=""><span class="mord mathnormal" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqcs" style=""><span class="mord" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.8360000000000001em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.050005em;"><span 
style="top:-1.8723309999999997em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqda" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span><span style="top:-3.0500049999999996em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.413777em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord coloredeq eqcy" style=""><span class="mord mathnormal" style="">e</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.891331em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqbu" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.05764em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="mord coloredeq eqck" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:6.1458365em;"><span></span></span></span></span></span></span></span></span></span></span></span></span><p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqbu" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal 
mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> is the attention score matrix before softmax, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqch" style=""><span class="mord" style=""><span class="mord mathnormal" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> is the softmax denominator, and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqbt" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord 
mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> is the attention matrix after softmax.</p>
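<p>As a point of reference, the standard computation for a single query can be sketched in plain Python. This is an illustrative sketch, not the implementation on this page: the function name <code>standard_attention</code>, the list-based vectors and the toy inputs are assumptions made for the example.</p>

```python
import math


def standard_attention(q, keys, values, scale):
    """Attention output for one query vector `q` (plain Python lists)."""
    # S_ij: scaled dot product of the query with every key
    scores = [scale * sum(qc * kc for qc, kc in zip(q, k)) for k in keys]
    # L_i: the softmax denominator
    denom = sum(math.exp(s) for s in scores)
    # P_ij: attention weights after softmax
    probs = [math.exp(s) / denom for s in scores]
    # O_i: weighted sum of the value vectors
    dim = len(values[0])
    out = [sum(p * v[c] for p, v in zip(probs, values)) for c in range(dim)]
    return out, probs


# Toy inputs (hypothetical): one query, three keys, three values
q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, probs = standard_attention(q, keys, values, scale=1 / math.sqrt(2))
```

<p>Note that this version needs all the scores for a query before it can normalize any of them, which is the step the optimization below avoids.</p>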
<h4>Flash Attention Optimization</h4>
<p>You can compute <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqci" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>, instead of doing the full softmax, by computing the sum of exponents <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcl" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.01968em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and the 
unnormalized output <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.0701899999999998em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbo" style=""><span class="mord" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9201899999999998em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.02778em">O</span></span><span style="top:-3.6023300000000003em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord" style="">~</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> while iterating over keys:</p>
<span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:4.706082em;vertical-align:-2.103041em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:2.603041em;"><span style="top:-4.68848em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqbu" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-3.137149em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqcl" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.01968em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span 
class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-1.556959em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqbo" style=""><span class="mord" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9201899999999998em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.02778em">O</span></span><span style="top:-3.6023300000000003em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord" style="">~</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:2.103041em;"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:2.603041em;"><span style="top:-4.68848em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"></span><span class="mspace" 
style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord coloredeq eqbv" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span></span><span class="mord coloredeq eqcj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqbx" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" 
style="height:0.914561em;"><span style="top:-3.1362300000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.13889em">T</span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.137149em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord coloredeq eqcl" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.01968em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mord coloredeq eqcy" style=""><span class="mord mathnormal" style="">e</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.891331em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqbu" 
style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.05764em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-1.556959em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord coloredeq eqbo" style=""><span class="mord" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9201899999999998em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.02778em">O</span></span><span style="top:-3.6023300000000003em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord" style="">~</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span 
style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mord coloredeq eqcy" style=""><span class="mord mathnormal" style="">e</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.891331em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqbu" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.05764em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.22222em;">V</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.22222em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqda" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:2.103041em;"><span></span></span></span></span></span></span></span></span></span></span></span></span><p>Finally, you can compute:</p>
<p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:2.4331899999999997em;vertical-align:-0.8360000000000001em;"></span><span class="mord coloredeq eqy" style=""><span class="mord" style=""><span class="mord coloredeq eqci" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.5971899999999999em;"><span style="top:-2.3139999999999996em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqcl" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.01968em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" 
style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqbo" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9201899999999998em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.02778em">O</span></span><span style="top:-3.6023300000000003em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord" style="">~</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.8360000000000001em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></span></span></p>
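<p>The accumulation above can be sketched for a single query in plain Python. This is a minimal illustration under stated assumptions, not the kernel described on this page: <code>streaming_attention</code> is a hypothetical name, vectors are plain lists, keys are visited one at a time, and the term added to the unnormalized output is the exponentiated score times the value vector for that key.</p>

```python
import math


def streaming_attention(q, keys, values, scale):
    """One pass over the keys, keeping only a running sum and a running output."""
    l = 0.0                                # running sum of exponents, l_i
    o_tilde = [0.0] * len(values[0])       # unnormalized output
    for k, v in zip(keys, values):
        # Score for this key only; no full row of scores is stored
        s = scale * sum(qc * kc for qc, kc in zip(q, k))
        w = math.exp(s)
        l += w
        o_tilde = [o + w * vc for o, vc in zip(o_tilde, v)]
    # Normalize once at the end: O_i = unnormalized output / l_i
    return [o / l for o in o_tilde]


# Toy inputs (hypothetical)
result = streaming_attention(
    q=[1.0, 0.0],
    keys=[[1.0, 0.0], [0.0, 1.0]],
    values=[[1.0, 2.0], [3.0, 4.0]],
    scale=1.0,
)
```

<p>The result matches the full softmax-weighted sum, but this form still exponentiates raw scores, so large scores can overflow.</p>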
<p>To make this numerically stable, Flash Attention subtracts the running maximum of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqbu" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> (over the keys seen so far) before exponentiating, so every exponent is at most zero and cannot overflow.</p>
<p>So it maintains the following while iterating over keys:</p>
<ul><li><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcm" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>, the max <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqbu" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> </li>
<li><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcl" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.01968em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>, the sum of exponents <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.277149em;vertical-align:-0.43581800000000004em;"></span><span class="mop"><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em;">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16195399999999993em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqda" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.43581800000000004em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord 
coloredeq eqcy" style=""><span class="mord mathnormal" style="">e</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.841331em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqbu" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.05764em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span><span class="mbin mtight">−</span><span class="mord mtight coloredeq eqcm" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span 
class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span>, and </li>
<li><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.0701899999999998em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbo" style=""><span class="mord" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9201899999999998em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.02778em">O</span></span><span style="top:-3.6023300000000003em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord" style="">~</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>, the unnormalized output</li></ul>
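<p>One way these three quantities might be maintained is sketched below for a single query, visiting one key at a time rather than a block of keys. The function name <code>stable_streaming_attention</code> and the toy inputs are assumptions made for the example. Whenever the running maximum grows, the previously accumulated sum and output are rescaled so that all terms stay relative to the same maximum.</p>

```python
import math


def stable_streaming_attention(q, keys, values, scale):
    """Streaming attention for one query with a running max for stability."""
    m = float("-inf")                      # running max of the scores, m_i
    l = 0.0                                # running sum of exp(S_ij - m_i), l_i
    o_tilde = [0.0] * len(values[0])       # unnormalized output
    for k, v in zip(keys, values):
        s = scale * sum(qc * kc for qc, kc in zip(q, k))
        m_new = max(m, s)
        # Factor that re-expresses the old terms relative to the new max.
        # On the first key m is -inf, so alpha is exp(-inf) == 0.0.
        alpha = math.exp(m - m_new)
        p = math.exp(s - m_new)            # exponent is always <= 0
        l = l * alpha + p
        o_tilde = [o * alpha + p * vc for o, vc in zip(o_tilde, v)]
        m = m_new
    return [o / l for o in o_tilde]


# Scores of about 1000 and 999: exponentiating them directly would overflow
result = stable_streaming_attention(
    q=[1000.0],
    keys=[[1.0], [0.999]],
    values=[[1.0, 0.0], [0.0, 1.0]],
    scale=1.0,
)
```

<p>With these inputs, calling <code>math.exp</code> on the raw scores raises <code>OverflowError</code>, while the sketch above only exponentiates non-positive numbers. The three running quantities are the only state carried from one key to the next.</p>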
<p>For each block of keys <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.85396em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord coloredeq eqda" style=""><span class="mord mathnormal" style="margin-right:0.05724em">j</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqcs" style=""><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="minner">…</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord coloredeq eqda" style=""><span class="mord mathnormal" style="margin-right:0.05724em">j</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqct" style=""><span class="mord mtight" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> it updates them:</p>
<span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:10.645828000000002em;vertical-align:-5.072914em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:5.572914000000001em;"><span style="top:-8.192359000000002em;"><span class="pstrut" style="height:3.858777000000001em;"></span><span class="mord"><span class="mord"><span class="mord coloredeq eqcm" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7143919999999999em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">n</span><span class="mord mtight coloredeq eqcy" style=""><span class="mord mtight" style="">e</span></span><span class="mord mtight">w</span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-6.108397000000002em;"><span class="pstrut" style="height:3.858777000000001em;"></span><span class="mord"><span class="mord"><span class="mord accent"><span 
class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9201899999999998em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">P</span></span><span style="top:-3.6023300000000003em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord">~</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span><span class="mord mtight coloredeq eqda" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.5896200000000014em;"><span class="pstrut" style="height:3.858777000000001em;"></span><span class="mord"><span class="mord coloredeq eqcl" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.01968em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span 
class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-0.945863000000001em;"><span class="pstrut" style="height:3.858777000000001em;"></span><span class="mord"><span class="mord coloredeq eqbo" style=""><span class="mord" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9201899999999998em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.02778em">O</span></span><span style="top:-3.6023300000000003em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord" style="">~</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:0.554136999999999em;"><span class="pstrut" style="height:3.858777000000001em;"></span><span class="mord"><span class="mord coloredeq eqcm" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" 
style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:5.072914em;"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:5.572914000000001em;"><span style="top:-8.192359000000002em;"><span class="pstrut" style="height:3.858777000000001em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mop">max</span><span class="mopen">(</span><span class="mord coloredeq eqcm" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.239332em;"><span style="top:-2.3723360000000002em;margin-left:0em;"><span 
class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqda" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span><span class="mrel mtight">=</span><span class="mord mtight coloredeq eqda" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span><span class="mord mtight coloredeq eqcs" style=""><span class="mord mtight" style="">1</span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop">max</span></span></span><span style="top:-3.677668em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqda" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span><span class="mord mtight coloredeq eqct" style=""><span class="mord mtight" style="">2</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.863772em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbu" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" 
style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span><span style="top:-6.108397000000002em;"><span class="pstrut" style="height:3.858777000000001em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mop">exp</span><span class="mopen">(</span><span class="mord coloredeq eqbu" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mord coloredeq eqcm" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7143919999999999em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">n</span><span class="mord mtight coloredeq eqcy" style=""><span class="mord mtight" style="">e</span></span><span class="mord mtight">w</span></span></span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span><span style="top:-3.5896200000000014em;"><span class="pstrut" style="height:3.858777000000001em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord coloredeq eqz" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9299799999999999em;"><span style="top:-3.1130000000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcm" 
style=""><span class="mord mathnormal mtight" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span><span class="mbin mtight" style="">−</span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7385428571428572em;"><span style="top:-2.214em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord text mtight" style=""><span class="mord mtight" style="">n</span><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcy" style="">e</span></span><span class="mord mtight" style="">w</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="mord coloredeq eqcl" 
style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.01968em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord coloredeq eqx" style=""><span class="mop op-limits" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.8587770000000006em;"><span style="top:-1.872331em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span><span class="mrel mtight" style="">=</span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcs" style="">1</span></span></span></span></span><span style="top:-3.050005em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op" style="">∑</span></span></span><span style="top:-4.347113em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight" 
style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.4137769999999998em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9201899999999998em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.13889em">P</span></span><span style="top:-3.6023300000000003em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord" style="">~</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-0.945863000000001em;"><span class="pstrut" style="height:3.858777000000001em;"></span><span class="mord"><span 
class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord coloredeq eqz" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9299799999999999em;"><span style="top:-3.1130000000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcm" style=""><span class="mord mathnormal mtight" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span><span class="mbin mtight" style="">−</span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7385428571428572em;"><span style="top:-2.214em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" 
style="">i</span></span></span></span></span><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord text mtight" style=""><span class="mord mtight" style="">n</span><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcy" style="">e</span></span><span class="mord mtight" style="">w</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="mord coloredeq eqbo" style=""><span class="mord" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9201899999999998em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.02778em">O</span></span><span style="top:-3.6023300000000003em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord" style="">~</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span 
class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mord accent"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9201899999999998em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">P</span></span><span style="top:-3.6023300000000003em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord">~</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span><span class="mord mtight coloredeq eqda" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">∗</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord coloredeq eqck" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" 
style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span><span style="top:0.554136999999999em;"><span class="pstrut" style="height:3.858777000000001em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord"><span class="mord mathnormal">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7143919999999999em;"><span style="top:-2.4530000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">n</span><span class="mord mtight coloredeq eqcy" style=""><span class="mord mtight" style="">e</span></span><span class="mord mtight">w</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:5.072914em;"><span></span></span></span></span></span></span></span></span></span></span></span></span><p>Then finally,</p>
<p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:2.4331899999999997em;vertical-align:-0.8360000000000001em;"></span><span class="mord coloredeq eqy" style=""><span class="mord" style=""><span class="mord coloredeq eqci" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.5971899999999999em;"><span style="top:-2.3139999999999996em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqcl" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.01968em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" 
style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqbo" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9201899999999998em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.02778em">O</span></span><span style="top:-3.6023300000000003em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord" style="">~</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.8360000000000001em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></span></span></p>
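<p>The update rules above can be checked with a small sketch for a single query row. This is plain Python rather than the Triton kernel, and the names <code>attention_row_blockwise</code> and <code>attention_row_reference</code> are hypothetical helpers introduced only for illustration. It assumes the scores for the row are already given, and it treats the upper block index as exclusive, unlike the inclusive range in the equations.</p>

```python
import math


def attention_row_reference(s_row, v):
    """Ordinary softmax attention for one query row (uses all scores at once)."""
    m = max(s_row)
    p = [math.exp(s - m) for s in s_row]
    total = sum(p)
    dim = len(v[0])
    return [sum(p_j * v_j[k] for p_j, v_j in zip(p, v)) / total for k in range(dim)]


def attention_row_blockwise(s_row, v, block_size):
    """Same output, computed one block of keys at a time.

    s_row: scores S_ij for a fixed query i (list of floats)
    v: value vectors V_j (list of equal-length lists of floats)
    """
    dim = len(v[0])
    m_i = float("-inf")   # running maximum m_i
    l_i = 0.0             # running softmax denominator l_i
    o_acc = [0.0] * dim   # un-normalized output, O~_i in the equations

    for j1 in range(0, len(s_row), block_size):
        j2 = min(j1 + block_size, len(s_row))  # exclusive upper bound
        s_blk, v_blk = s_row[j1:j2], v[j1:j2]

        m_new = max(m_i, max(s_blk))                   # m_i^new
        p_blk = [math.exp(s - m_new) for s in s_blk]   # P~_ij
        scale = math.exp(m_i - m_new)                  # 0.0 for the first block

        l_i = scale * l_i + sum(p_blk)
        for k in range(dim):
            pv = sum(p * v_j[k] for p, v_j in zip(p_blk, v_blk))
            o_acc[k] = scale * o_acc[k] + pv
        m_i = m_new

    # Final normalization: O_i = O~_i / l_i
    return [o / l_i for o in o_acc]


scores = [0.5, -1.0, 2.0, 0.0, 3.5]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.25, 4.0]]
blockwise = attention_row_blockwise(scores, values, block_size=2)
reference = attention_row_reference(scores, values)
print(all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(blockwise, reference)))
```

<p>The rescaling factor is what keeps the earlier partial sums consistent when a later block raises the running maximum. Only one block of scores is held at any time.</p>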
<p>This reduces memory usage, since we never have to materialize the full <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqbu" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> matrix or <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqbt" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" 
style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> matrix. It is also faster, since we don&#x27;t have to load these large matrices from memory; instead, it only loads blocks of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcv" style=""><span class="mord mathnormal" style="margin-right:0.07153em">K</span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcx" style=""><span class="mord mathnormal" style="margin-right:0.22222em">V</span></span></span></span></span></span> as it iterates over them.</p>
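<p>To give a rough sense of the saving, the sketch below counts score entries only. The sequence length and block sizes are hypothetical numbers chosen for illustration, not values taken from this implementation, and the helper names are made up for this example.</p>

```python
def full_score_elements(n_queries, n_keys):
    """Entries in the full S (or P) matrix for one attention head."""
    return n_queries * n_keys


def block_score_elements(block_q, block_k):
    """Entries in the single block of scores that is live at any one time."""
    return block_q * block_k


seq_len = 8192               # hypothetical sequence length
block_q, block_k = 128, 64   # hypothetical block sizes

full = full_score_elements(seq_len, seq_len)
tile = block_score_elements(block_q, block_k)
print(full, tile, full // tile)
```

<p>Under these assumptions the full matrix grows quadratically with the sequence length, while the live block stays the same size however long the sequence is.</p>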
<h2>Backward pass</h2>
<p>Here&#x27;s the standard backward pass. <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqci" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> is the gradient of the loss with respect to the output <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqci" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>.</p>
<span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:15.645819000000001em;vertical-align:-7.5729095em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:8.072909500000002em;"><span style="top:-10.072909500000002em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"><span class="mord coloredeq eqce" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqck" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span><span style="top:-7.580679500000002em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"><span class="mord mathnormal">d</span><span class="mord coloredeq eqbt" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing 
reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-6.080679500000002em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"><span class="mord mathnormal">d</span><span class="mord coloredeq eqbu" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-4.370674500000002em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"></span></span><span style="top:-1.7185615000000016em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"></span></span><span style="top:0.18145349999999838em;"><span class="pstrut" 
style="height:3.0500050000000005em;"></span><span class="mord"><span class="mord coloredeq eqcd" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqcj" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span style="top:2.945235499999999em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"><span class="mord coloredeq eqcb" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:7.5729095em;"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:8.072909500000002em;"><span style="top:-10.072909500000002em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0500050000000003em;"><span style="top:-1.872331em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span style="top:-3.050005em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.277669em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbt" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" 
style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqci" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span style="top:-7.580679500000002em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord mathnormal">d</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcz" style=""><span 
class="mord mathnormal mtight" style="">i</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord coloredeq eqbz" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqck" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.914561em;"><span style="top:-3.1362300000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.13889em">T</span></span></span></span></span></span></span></span></span></span></span><span style="top:-6.080679500000002em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord mathnormal">d</span><span class="mord text"><span class="mord">softmax</span></span><span class="mopen">(</span><span class="mord 
mathnormal">d</span><span class="mord coloredeq eqbt" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span><span style="top:-4.370674500000002em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0500050000000005em;"><span style="top:-1.847887em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span><span style="top:-3.050005em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.3021129999999999em;"><span></span></span></span></span></span><span 
class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord coloredeq eqbn" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03785em">δ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361079999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03785em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord coloredeq eqbt" style=""><span 
class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mord mathnormal">d</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-1.7185615000000016em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span 
class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord coloredeq eqbt" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mord mathnormal">d</span><span class="mord coloredeq eqbt" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mspace" 
style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord coloredeq eqbt" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop op-symbol large-op" style="position:relative;top:-0.000004999999999977245em;">∑</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord mathnormal">d</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:0.18145349999999838em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord coloredeq eqbv" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.050005em;"><span style="top:-1.8723309999999997em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqda" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span><span style="top:-3.0500049999999996em;"><span class="pstrut" 
style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.413777em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal">d</span><span class="mord coloredeq eqbu" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span><span style="top:2.945235499999999em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord coloredeq eqbv" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0500050000000003em;"><span style="top:-1.872331em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span style="top:-3.050005em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.277669em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal">d</span><span class="mord coloredeq eqbu" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" 
style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqcj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:7.5729095em;"><span></span></span></span></span></span></span></span></span></span></span></span></span><p>where <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.980548em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqbn" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03785em">δ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361079999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03785em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" 
style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqcs" style=""><span class="mord" style="">1</span></span></span></span></span></span> when <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.85396em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqda" style=""><span class="mord mathnormal" style="margin-right:0.05724em">j</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.03148em;">k</span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqcr" style=""><span class="mord" style="">0</span></span></span></span></span></span> otherwise.</p>
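<p>The equations above can be checked numerically. The following is a minimal PyTorch sketch that materializes the full matrices, so it illustrates the standard backward pass and not the memory-efficient one. The function name <code>attention_backward</code>, the single-head 2D shapes and the scale of one over the square root of the head dimension are assumptions made for this example. The results are compared with the gradients from autograd.</p>

```python
import torch


def attention_backward(q, k, v, d_o):
    """Sketch of the standard backward pass for q [n_q, d], k [n_kv, d],
    v [n_kv, d_v] and output gradient d_o [n_q, d_v] (hypothetical helper)."""
    sigma = q.shape[-1] ** -0.5                 # softmax scale (assumed 1/sqrt(d))
    p = torch.softmax(sigma * q @ k.T, dim=-1)  # P_ij, recomputed from Q and K
    d_v = p.T @ d_o                             # dV_j = sum_i P_ij dO_i
    d_p = d_o @ v.T                             # dP_ij = dO_i V_j^T
    # dS_ij = P_ij dP_ij - P_ij sum_k P_ik dP_ik
    d_s = p * d_p - p * (p * d_p).sum(dim=-1, keepdim=True)
    d_q = sigma * d_s @ k                       # dQ_i = sigma sum_j dS_ij K_j
    d_k = sigma * d_s.T @ q                     # dK_j = sigma sum_i dS_ij Q_i
    return d_q, d_k, d_v


torch.manual_seed(0)
q = torch.randn(5, 8, requires_grad=True)
k = torch.randn(7, 8, requires_grad=True)
v = torch.randn(7, 8, requires_grad=True)
d_o = torch.randn(5, 8)

# Reference gradients from autograd
o = torch.softmax(8 ** -0.5 * q @ k.T, dim=-1) @ v
o.backward(d_o)

d_q, d_k, d_v = attention_backward(q.detach(), k.detach(), v.detach(), d_o)
print(torch.allclose(d_q, q.grad, atol=1e-5),
      torch.allclose(d_k, k.grad, atol=1e-5),
      torch.allclose(d_v, v.grad, atol=1e-5))
```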
<p>The Flash Attention paper introduces <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcf" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcu" style="margin-right:0.02778em">D</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> to simplify the computation of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathnormal">d</span><span class="mord mathnormal" style="margin-right:0.05764em;">S</span></span></span></span></span>.</p>
<span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:9.530915000000002em;vertical-align:-4.515457500000001em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:5.015457500000001em;"><span style="top:-7.015457500000001em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"><span class="mord coloredeq eqcf" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcu" style="margin-right:0.02778em">D</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-4.3633395em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"></span></span><span style="top:-1.7112214999999995em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"></span></span><span style="top:0.8054525000000003em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:4.515457500000001em;"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:5.015457500000001em;"><span style="top:-7.015457500000001em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0500050000000005em;"><span style="top:-1.847887em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span><span style="top:-3.050005em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.3021129999999999em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord mathnormal">d</span><span class="mord"><span class="mord mathnormal" 
style="margin-right:0.13889em;">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-4.3633395em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0500050000000005em;"><span style="top:-1.847887em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span><span style="top:-3.050005em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.3021129999999999em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span 
class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqci" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span class="mord"><span class="mord coloredeq eqcx" style=""><span class="mord mathnormal" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8913309999999999em;"><span style="top:-2.4530000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span><span 
style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span></span></span><span style="top:-1.7112214999999995em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqci" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0500050000000005em;"><span style="top:-1.847887em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" 
style="margin-right:0.03148em;">k</span></span></span><span style="top:-3.050005em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.3021129999999999em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord coloredeq eqcx" style=""><span class="mord mathnormal" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8913309999999999em;"><span style="top:-2.4530000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span></span></span><span style="top:0.8054525000000003em;"><span class="pstrut" style="height:3.0500050000000005em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqci" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span class="mord"><span class="mord coloredeq eqci" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span 
class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.914561em;"><span style="top:-3.1362300000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:4.515457500000001em;"><span></span></span></span></span></span></span></span></span></span></span></span></span><p>Then,</p>
<span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.5000000000000002em;vertical-align:-0.5000000000000002em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1em;"><span style="top:-3.16em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal">d</span><span class="mord coloredeq eqbu" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord coloredeq eqbt" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span 
class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mord mathnormal">d</span><span class="mord coloredeq eqbt" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord coloredeq eqcf" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcu" style="margin-right:0.02778em">D</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span 
style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqbt" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.5000000000000002em;"><span></span></span></span></span></span></span></span></span></span></span></span></span><p>Flash attention saves <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqch" style=""><span class="mord" style=""><span class="mord mathnormal" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> from the forward pass since it doesn&#x27;t take much memory. So during the backward pass it doesn&#x27;t have to keep computing <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcl" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.01968em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> or <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcm" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>.</p>
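<p>The two forms of <code>D_i</code> and the resulting expression for <code>dS_ij</code> can be checked numerically against PyTorch autograd. The following is a minimal sketch, not the implementation on this page: it uses a single head with no masking, and the tensor sizes and variable names are assumptions chosen for illustration.</p>

```python
import torch

torch.manual_seed(0)
n_q, n_k, d = 4, 6, 8          # illustrative sizes (assumed)
sigma = d ** -0.5              # softmax scale

q = torch.randn(n_q, d, dtype=torch.float64, requires_grad=True)
k = torch.randn(n_k, d, dtype=torch.float64)
v = torch.randn(n_k, d, dtype=torch.float64)

s = sigma * q @ k.T            # scores S_ij
s.retain_grad()                # keep dS from autograd for comparison
p = torch.softmax(s, dim=-1)   # P_ij
o = p @ v                      # O_i
do = torch.randn_like(o)       # upstream gradient dO_i
o.backward(do)

p, o = p.detach(), o.detach()
dp = do @ v.T                         # dP_ij = dO_i V_j^T
D = (do * o).sum(dim=-1)              # D_i = dO_i O_i^T
D_from_p = (p * dp).sum(dim=-1)       # D_i = sum_k P_ik dP_ik
ds = p * dp - D[:, None] * p          # dS_ij = P_ij dP_ij - D_i P_ij

assert torch.allclose(D, D_from_p)
assert torch.allclose(ds, s.grad)
```

<p>Computing <code>D_i</code> from <code>dO_i</code> and <code>O_i</code> needs only two vectors per query, so it does not require a full row of <code>P</code> to be available at once.</p>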
<p>It first computes <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcf" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcu" style="margin-right:0.02778em">D</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>. Then it iterates over the queries and computes (accumulates) <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.980548em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqcb" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span 
class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.980548em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqce" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqck" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span></span>. 
Finally it iterates over the keys and computes (accumulates) <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcd" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqcj" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></span>.</p>
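<p>The order of these loops can be illustrated with a plain PyTorch version of the backward pass that works tile by tile and recomputes <code>P</code> from the saved <code>L_i</code>. This is a rough sketch under several assumptions: a single head, no causal mask, natural exponentials, and a hypothetical helper name <code>attention_backward_tiled</code> with an arbitrary tile size. It is not the Triton kernel itself.</p>

```python
import torch


def attention_backward_tiled(q, k, v, o, do, lse, sigma, block=2):
    """Single-head, non-causal attention backward, one tile at a time."""
    n_q, n_k = q.shape[0], k.shape[0]
    dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)

    # Step 1: D_i = dO_i O_i^T for every query
    D = (do * o).sum(dim=-1)

    # Step 2: for each key/value tile, iterate over the queries
    # and accumulate dK_j and dV_j
    for j in range(0, n_k, block):
        kj, vj = k[j:j + block], v[j:j + block]
        for i in range(0, n_q, block):
            qi, doi = q[i:i + block], do[i:i + block]
            p = torch.exp(sigma * qi @ kj.T - lse[i:i + block, None])
            ds = p * (doi @ vj.T - D[i:i + block, None])
            dv[j:j + block] += p.T @ doi
            dk[j:j + block] += sigma * ds.T @ qi

    # Step 3: for each query tile, iterate over the keys and accumulate dQ_i
    for i in range(0, n_q, block):
        qi, doi = q[i:i + block], do[i:i + block]
        for j in range(0, n_k, block):
            kj, vj = k[j:j + block], v[j:j + block]
            p = torch.exp(sigma * qi @ kj.T - lse[i:i + block, None])
            ds = p * (doi @ vj.T - D[i:i + block, None])
            dq[i:i + block] += sigma * ds @ kj

    return dq, dk, dv


# Compare with autograd on small random inputs
torch.manual_seed(0)
n_q, n_k, d = 6, 8, 4
sigma = d ** -0.5
q = torch.randn(n_q, d, dtype=torch.float64, requires_grad=True)
k = torch.randn(n_k, d, dtype=torch.float64, requires_grad=True)
v = torch.randn(n_k, d, dtype=torch.float64, requires_grad=True)

s = sigma * q @ k.T
o = torch.softmax(s, dim=-1) @ v
lse = torch.logsumexp(s, dim=-1)   # L_i saved from the forward pass
do = torch.randn_like(o)
o.backward(do)

with torch.no_grad():
    dq, dk, dv = attention_backward_tiled(q, k, v, o, do, lse, sigma)

assert torch.allclose(dq, q.grad)
assert torch.allclose(dk, k.grad)
assert torch.allclose(dv, v.grad)
```

<p>Each tile of <code>P</code> is rebuilt from <code>L_i</code> alone, which is why neither the row maximum nor the row sum has to be tracked again during the backward pass.</p>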
<p>In both the forward and backward passes, we compute logarithms and exponentials in base <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqct" style=""><span class="mord" style="">2</span></span></span></span></span></span> instead of base <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqcy" style=""><span class="mord mathnormal" style="">e</span></span></span></span></span></span> for performance.</p>
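<p>Switching bases relies on the identity <code>exp(x) = 2^(x * log2(e))</code>. The sketch below shows that softmax numerators and the log-sum-exp agree between the two bases once the scores are multiplied by <code>log2(e)</code>. The constant names and tensor sizes are assumptions for illustration; where exactly the kernel applies this factor is not shown here.</p>

```python
import math

import torch

LOG2_E = math.log2(math.e)   # multiply scores by this before exp2
LN_2 = math.log(2.0)         # converts a base-2 logarithm back to base e

torch.manual_seed(0)
s = torch.randn(4, 6, dtype=torch.float64)   # illustrative scores

# Base e: stabilized exponentials and log-sum-exp
m_e = s.max(dim=-1, keepdim=True).values
p_e = torch.exp(s - m_e)
lse_e = m_e.squeeze(-1) + torch.log(p_e.sum(dim=-1))

# Base 2: scale the scores once, then use exp2 / log2 throughout
s2 = s * LOG2_E
m_2 = s2.max(dim=-1, keepdim=True).values
p_2 = torch.exp2(s2 - m_2)
lse_2 = m_2.squeeze(-1) + torch.log2(p_2.sum(dim=-1))

assert torch.allclose(p_e, p_2)
assert torch.allclose(lse_e, lse_2 * LN_2)
```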

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">148</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Tuple</span>
<span class="lineno">149</span>
<span class="lineno">150</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">151</span><span class="kn">import</span> <span class="nn">triton</span>
<span class="lineno">152</span><span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
<span class="lineno">153</span>
<span class="lineno">154</span><span class="n">HI_PRES_TL</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">float32</span>
<span class="lineno">155</span><span class="n">HI_PRES_TORCH</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-1'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-1'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">158</span><span class="k">class</span> <span class="nc">AttentionFunc</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-2'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-2'>#</a>
            </div>
            <h3>Forward pass</h3>
<p>Grouped-query attention forward pass. Returns the output with shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">]</span></code>
.</p>
<ul><li><code  class="highlight"><span></span><span class="n">ctx</span></code>
  is the autograd context, used to save tensors and attributes for the backward pass </li>
<li><code  class="highlight"><span></span><span class="n">q</span></code>
  has shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">]</span></code>
 </li>
<li><code  class="highlight"><span></span><span class="n">k</span></code>
  has shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">]</span></code>
 </li>
<li><code  class="highlight"><span></span><span class="n">v</span></code>
  has shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">]</span></code>
 </li>
<li><code  class="highlight"><span></span><span class="n">causal</span></code>
  specifies whether to apply a causal attention mask </li>
<li><code  class="highlight"><span></span><span class="n">sm_scale</span></code>
  is the softmax scale factor <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqbv" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span></span></span></span></span></span></li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">159</span>    <span class="nd">@staticmethod</span>
<span class="lineno">160</span>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span>
<span class="lineno">161</span>                <span class="n">q</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">162</span>                <span class="n">causal</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-3'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-3'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">176</span>        <span class="n">batch_size</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span>
<span class="lineno">177</span>        <span class="n">_</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">k</span><span class="o">.</span><span class="n">shape</span>
<span class="lineno">178</span>        <span class="k">assert</span> <span class="n">n_heads</span> <span class="o">%</span> <span class="n">k_heads</span> <span class="o">==</span> <span class="mi">0</span>
<span class="lineno">179</span>        <span class="n">n_groups</span> <span class="o">=</span> <span class="n">n_heads</span> <span class="o">//</span> <span class="n">k_heads</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-4'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-4'>#</a>
            </div>
            <p>Shape constraints: all three tensors must share the same head dimension, and it must be one of the supported sizes </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">182</span>        <span class="k">assert</span> <span class="n">d_head</span> <span class="o">==</span> <span class="n">k</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="lineno">183</span>        <span class="k">assert</span> <span class="n">d_head</span> <span class="ow">in</span> <span class="p">{</span><span class="mi">16</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">256</span><span class="p">}</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-5'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-5'>#</a>
            </div>
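            <p>Reshape the tensors so that the key/value heads are folded into the batch dimension; the queries keep a separate dimension for the groups that share each key/value head </p>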

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">186</span>        <span class="n">q</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">n_groups</span><span class="p">,</span> <span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">)</span>
<span class="lineno">187</span>        <span class="n">k</span> <span class="o">=</span> <span class="n">k</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">)</span>
<span class="lineno">188</span>        <span class="n">v</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">)</span></pre></div>
        </div>
    </div>
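    <div class='section'>
        <div class='docs'>
            <p>As a rough illustration of what this reshape means for indexing, the sketch here maps a query head to its combined batch/key-value-head index and its group index. The sizes and the helper <code>split_head</code> are hypothetical and not part of the implementation; the mapping assumes the input is contiguous, so that a view preserves the flat order of the leading dimensions.</p>
        </div>
        <div class='code'>
```python
# Hypothetical sizes, chosen only for illustration.
batch_size, n_heads, k_heads = 2, 8, 2
n_groups = n_heads // k_heads  # query heads that share one key/value head


def split_head(b, h):
    """Map (batch index, query head) to (combined batch/kv-head index, group index)."""
    return b * k_heads + h // n_groups, h % n_groups


# A view of a contiguous tensor keeps the flat order of the leading
# dimensions, so the flat index must agree before and after the reshape.
for b in range(batch_size):
    for h in range(n_heads):
        bk, g = split_head(b, h)
        assert b * n_heads + h == bk * n_groups + g

# Batch 1, query head 5 lands in combined index 3 as group 1.
print(split_head(1, 5))
```
        </div>
    </div>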
    <div class='section' id='section-6'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-6'>#</a>
            </div>
            <p>Make sure the tensors are contiguous and that <code>k</code> and <code>v</code> have the same strides </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">191</span>        <span class="k">assert</span> <span class="n">q</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">()</span>
<span class="lineno">192</span>        <span class="k">assert</span> <span class="n">k</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">()</span>
<span class="lineno">193</span>        <span class="k">assert</span> <span class="n">v</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">()</span>
<span class="lineno">194</span>        <span class="k">assert</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">()</span> <span class="o">==</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">()</span></pre></div>
        </div>
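    <div class='section'>
        <div class='docs'>
            <p>For intuition, the strides of a contiguous (row-major) tensor follow directly from its shape. The sketch here computes them in plain Python for the reshaped tensors; <code>contiguous_strides</code> and the sizes are hypothetical. Since <code>k</code> and <code>v</code> have the same shape, their contiguous strides match, which is presumably what lets a kernel address both with the same offsets (an assumption about the kernel, which is not shown here).</p>
        </div>
        <div class='code'>
```python
def contiguous_strides(shape):
    """Row-major strides, counted in elements, for a tensor of the given shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


# Hypothetical sizes, chosen only for illustration.
batch_size, k_heads, n_groups = 2, 2, 4
q_seq_len, kv_seq_len, d_head = 16, 32, 64

q_strides = contiguous_strides((batch_size * k_heads, n_groups, q_seq_len, d_head))
k_strides = contiguous_strides((batch_size * k_heads, kv_seq_len, d_head))
v_strides = contiguous_strides((batch_size * k_heads, kv_seq_len, d_head))

print(q_strides, k_strides, v_strides)
```
        </div>
    </div>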
    </div>
    <div class='section' id='section-7'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-7'>#</a>
            </div>
            <p>Tensor for the output </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">197</span>        <span class="n">o</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">q</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-8'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-8'>#</a>
            </div>
            <p>Tensor for log of sum of exponentials <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.93858em;vertical-align:-0.24414em;"></span><span class="mord coloredeq eqbp" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqch" style=""><span class="mord mathnormal" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span 
class="strut" style="height:1.277149em;vertical-align:-0.43581800000000004em;"></span><span class="mord coloredeq eqbb" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop" style=""><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16195399999999993em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.43581800000000004em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.841331em;"><span 
style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbu" style=""><span class="mord mathnormal mtight" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.05764em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">199</span>        <span class="n">lse</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">n_groups</span><span class="p">,</span> <span class="n">q_seq_len</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">q</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">HI_PRES_TORCH</span><span class="p">)</span></pre></div>
        </div>
    </div>
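    <div class='section'>
        <div class='docs'>
            <p>A small numeric sketch of the base-2 log-sum-exp may help. It uses the identity that <code>e**s</code> equals <code>2**(s * log2(e))</code>, where <code>log2(e)</code> is the constant <code>1.4426950408889634</code> that appears in the kernel launch. The helper <code>lse2</code> and the example scores are hypothetical; this is plain Python, not the Triton kernel.</p>
        </div>
        <div class='code'>
```python
import math

LOG2_E = 1.4426950408889634  # log2(e)


def lse2(scores):
    """log2 of the sum of e**s over a row, computed stably in base 2."""
    scaled = [s * LOG2_E for s in scores]  # e**s == 2**(s * log2(e))
    m = max(scaled)                        # subtract the max to avoid overflow
    return m + math.log2(sum(2.0 ** (x - m) for x in scaled))


scores = [0.5, -1.0, 2.0]  # hypothetical row of already-scaled scores S_ij
row_lse = lse2(scores)

# With the row's log-sum-exp stored, each softmax probability can be
# recomputed later from a single score.
probs = [2.0 ** (s * LOG2_E - row_lse) for s in scores]
print(probs)
```
        </div>
    </div>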
    <div class='section' id='section-9'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-9'>#</a>
            </div>
            <p>The forward computation is parallelized across the combined batch and head dimension and across the queries, in blocks of size <code  class="highlight"><span></span><span class="n">BLOCK_Q</span></code>
 </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">202</span>        <span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">q_seq_len</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s2">&quot;BLOCK_Q&quot;</span><span class="p">]),</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">k_heads</span> <span class="o">*</span> <span class="n">n_groups</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="lineno">203</span>        <span class="n">_attn_fwd</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span>
<span class="lineno">204</span>            <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">sm_scale</span> <span class="o">*</span> <span class="mf">1.4426950408889634</span><span class="p">,</span> <span class="n">lse</span><span class="p">,</span> <span class="n">o</span><span class="p">,</span>
<span class="lineno">205</span>            <span class="n">n_groups</span><span class="o">=</span><span class="n">n_groups</span><span class="p">,</span>
<span class="lineno">206</span>            <span class="n">q_seq_len</span><span class="o">=</span><span class="n">q_seq_len</span><span class="p">,</span>
<span class="lineno">207</span>            <span class="n">kv_seq_len</span><span class="o">=</span><span class="n">kv_seq_len</span><span class="p">,</span>
<span class="lineno">208</span>            <span class="n">d_head</span><span class="o">=</span><span class="n">d_head</span><span class="p">,</span>
<span class="lineno">209</span>            <span class="n">is_causal</span><span class="o">=</span><span class="n">causal</span><span class="p">,</span>
<span class="lineno">210</span>        <span class="p">)</span></pre></div>
        </div>
    </div>
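    <div class='section'>
        <div class='docs'>
            <p>To make the launch grid concrete, here is the same arithmetic in plain Python. <code>cdiv</code> is a local stand-in for the ceiling division that <code>triton.cdiv</code> performs, and the sizes, including <code>BLOCK_Q = 64</code>, are assumed values for illustration; in the real launch <code>BLOCK_Q</code> comes from the kernel's meta-parameters.</p>
        </div>
        <div class='code'>
```python
def cdiv(a, b):
    """Ceiling division: the number of blocks of size b needed to cover a items."""
    return (a + b - 1) // b


# Hypothetical sizes, chosen only for illustration.
batch_size, k_heads, n_groups = 2, 2, 4
q_seq_len, BLOCK_Q = 1000, 64

# First axis: query blocks. Second axis: one program per (batch, head) pair.
grid = (cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads * n_groups, 1)
print(grid)  # (16, 16, 1)
```
        </div>
    </div>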
    <div class='section' id='section-10'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-10'>#</a>
            </div>
            <p>Save the reshaped inputs and outputs for the backward pass </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">213</span>        <span class="n">ctx</span><span class="o">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">o</span><span class="p">,</span> <span class="n">lse</span><span class="p">)</span>
<span class="lineno">214</span>        <span class="n">ctx</span><span class="o">.</span><span class="n">sm_scale</span> <span class="o">=</span> <span class="n">sm_scale</span>
<span class="lineno">215</span>        <span class="n">ctx</span><span class="o">.</span><span class="n">n_groups</span> <span class="o">=</span> <span class="n">n_groups</span>
<span class="lineno">216</span>        <span class="n">ctx</span><span class="o">.</span><span class="n">causal</span> <span class="o">=</span> <span class="n">causal</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-11'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-11'>#</a>
            </div>
            <p>Return the output in shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">]</span></code>
 </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">219</span>        <span class="k">return</span> <span class="n">o</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-12'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-12'>#</a>
            </div>
            <h3>Backward pass</h3>
<p>The backward pass computes the gradients of the input tensors.</p>
<ul><li><code  class="highlight"><span></span><span class="n">ctx</span></code>
  is the autograd context holding the tensors and attributes saved in the forward pass </li>
<li><code  class="highlight"><span></span><span class="n">do</span></code>
  is the gradient tensor of the attention output with shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">]</span></code>
</li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">221</span>    <span class="nd">@staticmethod</span>
<span class="lineno">222</span>    <span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">do</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">]:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-13'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-13'>#</a>
            </div>
            <p>Get saved tensors and attributes </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">233</span>        <span class="n">n_groups</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">n_groups</span>
<span class="lineno">234</span>        <span class="n">sm_scale</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">sm_scale</span>
<span class="lineno">235</span>        <span class="n">causal</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">causal</span>
<span class="lineno">236</span>        <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">o</span><span class="p">,</span> <span class="n">lse</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">saved_tensors</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-14'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-14'>#</a>
            </div>
            <p>Get shapes </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">239</span>        <span class="n">batch_size</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span> <span class="o">=</span> <span class="n">do</span><span class="o">.</span><span class="n">shape</span>
<span class="lineno">240</span>        <span class="n">_</span><span class="p">,</span> <span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">k</span><span class="o">.</span><span class="n">shape</span>
<span class="lineno">241</span>        <span class="n">k_heads</span> <span class="o">=</span> <span class="n">n_heads</span> <span class="o">//</span> <span class="n">n_groups</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-15'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-15'>#</a>
            </div>
            <p>Combine the heads with the batch dimension of the output gradients tensor </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">244</span>        <span class="n">do</span> <span class="o">=</span> <span class="n">do</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">n_groups</span><span class="p">,</span> <span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-16'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-16'>#</a>
            </div>
            <p>Make sure it&#x27;s contiguous and the strides are the same </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">247</span>        <span class="k">assert</span> <span class="n">do</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">()</span>
<span class="lineno">248</span>        <span class="k">assert</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">()</span> <span class="o">==</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">()</span>
<span class="lineno">249</span>        <span class="k">assert</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">()</span> <span class="o">==</span> <span class="n">o</span><span class="o">.</span><span class="n">stride</span><span class="p">()</span> <span class="o">==</span> <span class="n">do</span><span class="o">.</span><span class="n">stride</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-17'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-17'>#</a>
            </div>
            <p>Create tensors for input gradients </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">252</span>        <span class="n">dq</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">q</span><span class="p">)</span>
<span class="lineno">253</span>        <span class="n">dk</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">k</span><span class="p">)</span>
<span class="lineno">254</span>        <span class="n">dv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">v</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-18'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-18'>#</a>
            </div>
            <p>Precompute <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqbe" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbv" style="margin-right:0.03588em">σ</span></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mop coloredeq eqbr" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span><span class="mclose" style="">)</span><span class="mord" style=""><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" 
style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">257</span>        <span class="n">k_scaled</span> <span class="o">=</span> <span class="n">k</span> <span class="o">*</span> <span class="p">(</span><span class="n">sm_scale</span> <span class="o">*</span> <span class="mf">1.4426950408889634</span><span class="p">)</span></pre></div>
        </div>
    </div>
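    <div class='section'>
        <div class='docs'>
            <p>The reason this pre-scaling works can be checked numerically: because the dot product is linear, folding the softmax scale and <code>log2(e)</code> into the key once gives the same result as scaling every score and calling <code>exp</code>. The vectors, the value of <code>sm_scale</code>, and the helper <code>dot</code> are hypothetical and only serve this plain-Python sketch.</p>
        </div>
        <div class='code'>
```python
import math

LOG2_E = 1.4426950408889634  # log2(e)


def dot(a, b):
    """Dot product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))


# Hypothetical values, chosen only for illustration.
sm_scale = 0.125
q_i = [0.3, -1.2, 0.7]
k_j = [1.1, 0.4, -0.5]

# Scale the key once, as in k_scaled = k * (sm_scale * log2(e)).
k_scaled = [x * (sm_scale * LOG2_E) for x in k_j]

via_exp = math.exp(sm_scale * dot(q_i, k_j))  # natural exponential of the scaled score
via_exp2 = 2.0 ** dot(q_i, k_scaled)          # base-2 exponential with the pre-scaled key
print(via_exp, via_exp2)
```
        </div>
    </div>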
    <div class='section' id='section-19'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-19'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcf" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcu" style="margin-right:0.02778em">D</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.0999949999999998em;vertical-align:-0.258664em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-2.441336em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span><span class="mrel mtight">:</span></span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord 
mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.258664em;"><span></span></span></span></span></span></span><span class="mord mathnormal">d</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span><span class="mrel mtight">:</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.0999949999999998em;vertical-align:-0.258664em;"></span><span class="mord mathnormal">d</span><span class="mord"><span class="mord mathnormal">o</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-2.441336em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal 
mtight" style="margin-right:0.13889em;">T</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.258664em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathnormal">o</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">259</span>        <span class="n">pdp</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">lse</span><span class="p">)</span></pre></div>
        </div>
    </div>
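    <div class='section'>
        <div class='docs'>
            <p>A minimal sketch in plain Python (not the Triton kernel) that checks the identity above on a single query row. The toy values of <code>s_i</code>, <code>v</code> and <code>do_i</code> are made up for illustration. It assumes the usual attention relations: the output row is the probability-weighted sum of the value rows, and each entry of <code>dp_i</code> is the dot product of <code>do_i</code> with one value row.</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
import math


def softmax(row):
    # Numerically stable softmax over one row of scores
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    # Dot product of two equal-length lists
    return sum(x * y for x, y in zip(a, b))


# Hypothetical toy values: one query row, three keys, head dimension two
s_i = [0.5, -1.0, 2.0]
v = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]
do_i = [0.1, -0.2]

p_i = softmax(s_i)
# Output row: o_i = sum_j P_ij V_j
o_i = [sum(p_i[j] * v[j][c] for j in range(3)) for c in range(2)]
# Gradient w.r.t. the probabilities: dP_ij = do_i . V_j
dp_i = [dot(do_i, v[j]) for j in range(3)]

# The two ways of computing D_i
d_from_p = dot(p_i, dp_i)
d_from_o = dot(do_i, o_i)
print(d_from_p, d_from_o)
```
</pre></div>
        </div>
    </div>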
    <div class='section' id='section-20'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-20'>#</a>
            </div>
            <p>We use a fixed <code  class="highlight"><span></span><span class="n">BLOCK_Q</span></code>
 for the backward pass kernel that computes <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcu" style=""><span class="mord mathnormal" style="margin-right:0.02778em">D</span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre></pre></div>
        </div>
    </div>
    <div class='section' id='section-21'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-21'>#</a>
            </div>
            <p>Compute <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcf" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcu" style="margin-right:0.02778em">D</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>
<p>This is parallelized along the batch and queries in blocks of size <code  class="highlight"><span></span><span class="n">BLOCK_Q</span></code>
 </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">265</span>        <span class="n">BLOCK_Q</span> <span class="o">=</span> <span class="mi">16</span>
<span class="lineno">266</span>        <span class="n">pre_grid</span> <span class="o">=</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">q_seq_len</span><span class="p">,</span> <span class="n">BLOCK_Q</span><span class="p">),</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">k_heads</span><span class="p">)</span>
<span class="lineno">267</span>        <span class="n">_attn_bwd_d</span><span class="p">[</span><span class="n">pre_grid</span><span class="p">](</span>
<span class="lineno">268</span>            <span class="n">o</span><span class="p">,</span> <span class="n">do</span><span class="p">,</span>
<span class="lineno">269</span>            <span class="n">pdp</span><span class="p">,</span>
<span class="lineno">270</span>            <span class="n">BLOCK_Q</span><span class="o">=</span><span class="n">BLOCK_Q</span><span class="p">,</span>
<span class="lineno">271</span>            <span class="n">d_head</span><span class="o">=</span><span class="n">d_head</span><span class="p">,</span>
<span class="lineno">272</span>            <span class="n">q_seq_len</span><span class="o">=</span><span class="n">q_seq_len</span><span class="p">,</span>
<span class="lineno">273</span>            <span class="n">n_groups</span><span class="o">=</span><span class="n">n_groups</span><span class="p">,</span>
<span class="lineno">274</span>            <span class="n">num_stages</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="lineno">275</span>        <span class="p">)</span></pre></div>
        </div>
    </div>
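    <div class='section'>
        <div class='docs'>
            <p>To make the launch grid concrete, here is a small sketch of how its shape works out. The helper <code>cdiv</code> below is a plain-Python stand-in for the ceiling division that <code>triton.cdiv</code> performs, and the sizes are hypothetical values chosen for illustration. A sequence length that is not a multiple of <code>BLOCK_Q</code> still gets a final, partially filled block.</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
def cdiv(a, b):
    # Ceiling division: number of blocks of size b needed to cover a items
    return (a + b - 1) // b


# Hypothetical sizes for illustration
batch_size, k_heads, q_seq_len = 2, 4, 100
BLOCK_Q = 16

# One program per query block and per (batch, key head) pair
pre_grid = (cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
print(pre_grid)  # (7, 8)
```
</pre></div>
        </div>
    </div>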
    <div class='section' id='section-22'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-22'>#</a>
            </div>
            <p>Compute <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqcn" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqcq" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span></span></span></span></span></span></p>
<p>This is parallelized along the batch and keys in blocks of size <code  class="highlight"><span></span><span class="n">BLOCK_K</span></code>
 </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">280</span>        <span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK_K&#39;</span><span class="p">]),</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">k_heads</span><span class="p">)</span>
<span class="lineno">281</span>        <span class="n">_attn_bwd_dkdv</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span>
<span class="lineno">282</span>            <span class="n">q</span><span class="p">,</span> <span class="n">k_scaled</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">,</span> <span class="n">do</span><span class="p">,</span> <span class="n">dk</span><span class="p">,</span> <span class="n">dv</span><span class="p">,</span>
<span class="lineno">283</span>            <span class="n">lse</span><span class="p">,</span> <span class="n">pdp</span><span class="p">,</span>
<span class="lineno">284</span>            <span class="n">q_seq_len</span><span class="p">,</span> <span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">n_groups</span><span class="p">,</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">285</span>            <span class="n">is_causal</span><span class="o">=</span><span class="n">causal</span><span class="p">,</span>
<span class="lineno">286</span>
<span class="lineno">287</span>        <span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-23'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-23'>#</a>
            </div>
            <p>Compute <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqco" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span></span></span></span></span></span></p>
<p>This is parallelized along the batch and queries in blocks of size <code  class="highlight"><span></span><span class="n">BLOCK_Q</span></code>
 </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">292</span>        <span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">q_seq_len</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK_Q&#39;</span><span class="p">]),</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">k_heads</span> <span class="o">*</span> <span class="n">n_groups</span><span class="p">)</span>
<span class="lineno">293</span>        <span class="n">_attn_bwd_dq</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span>
<span class="lineno">294</span>            <span class="n">q</span><span class="p">,</span> <span class="n">k_scaled</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">do</span><span class="p">,</span>
<span class="lineno">295</span>            <span class="n">dq</span><span class="p">,</span>
<span class="lineno">296</span>            <span class="n">lse</span><span class="p">,</span> <span class="n">pdp</span><span class="p">,</span>
<span class="lineno">297</span>            <span class="n">q_seq_len</span><span class="p">,</span> <span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">n_groups</span><span class="p">,</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">298</span>            <span class="n">is_causal</span><span class="o">=</span><span class="n">causal</span><span class="p">,</span>
<span class="lineno">299</span>        <span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-24'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-24'>#</a>
            </div>
            <p>Split the combined batch and heads </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">302</span>        <span class="n">dq</span> <span class="o">=</span> <span class="n">dq</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">)</span>
<span class="lineno">303</span>        <span class="n">dk</span> <span class="o">=</span> <span class="n">dk</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">)</span>
<span class="lineno">304</span>        <span class="n">dv</span> <span class="o">=</span> <span class="n">dv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">)</span></pre></div>
        </div>
    </div>
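    <div class='section'>
        <div class='docs'>
            <p>A small sketch of what the split means for indexing, assuming the gradients are contiguous with a combined leading dimension of size <code>batch_size * n_heads</code> (that layout is set up outside this excerpt, so treat it as an assumption). Under a row-major view, row <code>z</code> of the combined dimension becomes batch <code>z // n_heads</code> and head <code>z % n_heads</code>. The sizes are hypothetical.</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
# Hypothetical sizes for illustration
batch_size, n_heads = 2, 3

# Row z of the combined dimension maps to (batch index, head index)
mapping = [divmod(z, n_heads) for z in range(batch_size * n_heads)]
print(mapping)  # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```
</pre></div>
        </div>
    </div>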
    <div class='section' id='section-25'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-25'>#</a>
            </div>
            <p> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">307</span>        <span class="k">return</span> <span class="n">dq</span><span class="p">,</span> <span class="n">dk</span><span class="p">,</span> <span class="n">dv</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span>
<span class="lineno">308</span>
<span class="lineno">309</span>
<span class="lineno">310</span><span class="n">attention</span> <span class="o">=</span> <span class="n">AttentionFunc</span><span class="o">.</span><span class="n">apply</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-26'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-26'>#</a>
            </div>
            <h4>Configs for auto-tuning</h4>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">313</span><span class="k">def</span> <span class="nf">_get_autotune_configs</span><span class="p">(</span><span class="n">inner_loop</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-27'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-27'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">318</span>    <span class="n">configs</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-28'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-28'>#</a>
            </div>
            <p>Possible options for <code  class="highlight"><span></span><span class="n">BLOCK_Q</span></code>
 </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">321</span>    <span class="k">for</span> <span class="n">bq</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">256</span><span class="p">]:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-29'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-29'>#</a>
            </div>
            <p>Possible options for <code  class="highlight"><span></span><span class="n">BLOCK_K</span></code>
 </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">323</span>        <span class="k">for</span> <span class="n">bk</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">256</span><span class="p">]:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-30'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-30'>#</a>
            </div>
            <p>If the inner loop is along keys, <code  class="highlight"><span></span><span class="n">BLOCK_Q</span></code>
 must be a multiple of <code  class="highlight"><span></span><span class="n">BLOCK_K</span></code>
 for causal masking </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">325</span>            <span class="k">if</span> <span class="n">inner_loop</span> <span class="o">==</span> <span class="s1">&#39;key&#39;</span> <span class="ow">and</span> <span class="n">bq</span> <span class="o">%</span> <span class="n">bk</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">326</span>                <span class="k">continue</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-31'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-31'>#</a>
            </div>
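            <p>Similarly, when the inner loop is along queries, <code  class="highlight"><span></span><span class="n">BLOCK_K</span></code> must be a multiple of <code  class="highlight"><span></span><span class="n">BLOCK_Q</span></code> </p>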

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">328</span>            <span class="k">if</span> <span class="n">inner_loop</span> <span class="o">==</span> <span class="s1">&#39;query&#39;</span> <span class="ow">and</span> <span class="n">bk</span> <span class="o">%</span> <span class="n">bq</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">329</span>                <span class="k">continue</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-32'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-32'>#</a>
            </div>
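            <p>Number of stages and warps. Configurations with 8 warps are skipped when <code  class="highlight"><span></span><span class="n">BLOCK_Q</span> <span class="o">*</span> <span class="n">BLOCK_K</span></code> is smaller than <code  class="highlight"><span></span><span class="mi">128</span> <span class="o">*</span> <span class="mi">128</span></code> </p>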

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">332</span>            <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">]:</span>
<span class="lineno">333</span>                <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">8</span><span class="p">]:</span>
<span class="lineno">334</span>                    <span class="k">if</span> <span class="n">bq</span> <span class="o">*</span> <span class="n">bk</span> <span class="o">&lt;</span> <span class="mi">128</span> <span class="o">*</span> <span class="mi">128</span> <span class="ow">and</span> <span class="n">w</span> <span class="o">==</span> <span class="mi">8</span><span class="p">:</span>
<span class="lineno">335</span>                        <span class="k">continue</span>
<span class="lineno">336</span>
<span class="lineno">337</span>                    <span class="n">configs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_Q&#39;</span><span class="p">:</span> <span class="n">bq</span><span class="p">,</span> <span class="s1">&#39;BLOCK_K&#39;</span><span class="p">:</span> <span class="n">bk</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="n">s</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="n">w</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-33'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-33'>#</a>
            </div>
            <p><strong>Use <code  class="highlight"><span></span><span class="k">return</span> <span class="n">configs</span></code>
 to autotune over all combinations. Only the first config is returned here because trying every combination is slow for testing.</strong> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">340</span>    <span class="k">return</span> <span class="n">configs</span><span class="p">[:</span><span class="mi">1</span><span class="p">]</span></pre></div>
        </div>
    </div>
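    <div class='section'>
        <div class='docs'>
            <p>To get a feel for the size of the search space, the sketch below repeats the same loops in plain Python. It collects <code>(BLOCK_Q, BLOCK_K, num_stages, num_warps)</code> tuples instead of <code>triton.Config</code> objects, so it runs without Triton. The function name <code>enumerate_configs</code> is hypothetical and not part of the implementation.</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
def enumerate_configs(inner_loop):
    # Same filtering as above, but collecting plain tuples
    configs = []
    for bq in [64, 128, 256]:
        for bk in [64, 128, 256]:
            # Block sizes must divide evenly along the inner loop
            if inner_loop == 'key' and bq % bk != 0:
                continue
            if inner_loop == 'query' and bk % bq != 0:
                continue
            for s in [2, 3, 4]:
                for w in [4, 8]:
                    # Skip 8 warps for tiles smaller than 128 x 128
                    if 128 * 128 > bq * bk and w == 8:
                        continue
                    configs.append((bq, bk, s, w))
    return configs


key_configs = enumerate_configs('key')
query_configs = enumerate_configs('query')
print(len(key_configs), len(query_configs))  # 30 30
print(key_configs[0])  # (64, 64, 2, 4)
```
</pre></div>
        </div>
    </div>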
    <div class='section' id='section-34'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-34'>#</a>
            </div>
            <h3>Triton kernel for Flash attention forward pass</h3>
<ul><li><code  class="highlight"><span></span><span class="n">t_q</span></code>
  queries <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8777699999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </li>
<li><code  class="highlight"><span></span><span class="n">t_k</span></code>
  keys <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> </li>
<li><code  class="highlight"><span></span><span class="n">t_v</span></code>
  values <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqck" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> </li>
<li><code  class="highlight"><span></span><span class="n">sm_scale_log2e</span></code>
  <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.93858em;vertical-align:-0.24414em;"></span><span class="mord coloredeq eqbv" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span></span><span class="mord coloredeq eqbr" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span></span></span></span></span> softmax scale multiplied by <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.93858em;vertical-align:-0.24414em;"></span><span class="mord coloredeq eqbr" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span 
class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span></span></span></span></span> </li>
<li><code  class="highlight"><span></span><span class="n">t_lse</span></code>
  <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.277149em;vertical-align:-0.43581800000000004em;"></span><span class="mord coloredeq eqbb" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop" style=""><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16195399999999993em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.43581800000000004em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span><span 
class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.841331em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbu" style=""><span class="mord mathnormal mtight" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.05764em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span> (out) </li>
<li><code  class="highlight"><span></span><span class="n">t_o</span></code>
  <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqci" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> output </li>
<li><code  class="highlight"><span></span><span class="n">n_groups</span></code>
  number of groups in GQA </li>
<li><code  class="highlight"><span></span><span class="n">q_seq_len</span></code>
  query sequence length </li>
<li><code  class="highlight"><span></span><span class="n">kv_seq_len</span></code>
  key/value sequence length </li>
<li><code  class="highlight"><span></span><span class="n">d_head</span></code>
  number of dimensions in a head </li>
<li><code  class="highlight"><span></span><span class="n">BLOCK_Q</span></code>
  block size for query sequence length </li>
<li><code  class="highlight"><span></span><span class="n">BLOCK_K</span></code>
  block size for key sequence length </li>
<li><code  class="highlight"><span></span><span class="n">is_causal</span></code>
  whether to apply causal masking</li></ul>
<p>Strides <code  class="highlight"><span></span><span class="n">z</span></code>
, <code  class="highlight"><span></span><span class="n">h</span></code>
, <code  class="highlight"><span></span><span class="n">m</span></code>
 and <code  class="highlight"><span></span><span class="n">d</span></code>
 denote the stride of the corresponding dimensions  (<code  class="highlight"><span></span><span class="n">batch_size</span></code>
, <code  class="highlight"><span></span><span class="n">n_heads</span></code>
, <code  class="highlight"><span></span><span class="n">q_seq_len</span></code>
, <code  class="highlight"><span></span><span class="n">d_head</span></code>
) in the query. Stride <code  class="highlight"><span></span><span class="n">n</span></code>
 denotes the stride along <code  class="highlight"><span></span><span class="n">kv_seq_len</span></code>
 of the key.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">343</span><span class="nd">@triton</span><span class="o">.</span><span class="n">autotune</span><span class="p">(</span><span class="n">_get_autotune_configs</span><span class="p">(</span><span class="n">inner_loop</span><span class="o">=</span><span class="s1">&#39;key&#39;</span><span class="p">),</span>
<span class="lineno">344</span>                 <span class="n">key</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;q_seq_len&quot;</span><span class="p">,</span> <span class="s2">&quot;kv_seq_len&quot;</span><span class="p">,</span> <span class="s2">&quot;d_head&quot;</span><span class="p">,</span> <span class="s2">&quot;n_groups&quot;</span><span class="p">,</span> <span class="s2">&quot;is_causal&quot;</span><span class="p">])</span>
<span class="lineno">345</span><span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="lineno">346</span><span class="k">def</span> <span class="nf">_attn_fwd</span><span class="p">(</span><span class="n">t_q</span><span class="p">,</span> <span class="n">t_k</span><span class="p">,</span> <span class="n">t_v</span><span class="p">,</span> <span class="n">sm_scale_log2e</span><span class="p">,</span> <span class="n">t_lse</span><span class="p">,</span> <span class="n">t_o</span><span class="p">,</span>
<span class="lineno">347</span>              <span class="n">n_groups</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">348</span>              <span class="n">q_seq_len</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">349</span>              <span class="n">kv_seq_len</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">350</span>              <span class="n">d_head</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">351</span>              <span class="n">is_causal</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">352</span>              <span class="n">BLOCK_Q</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">353</span>              <span class="n">BLOCK_K</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">354</span>              <span class="p">):</span></pre></div>
        </div>
    </div>
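    <div class='section'>
        <div class='docs'>
            <p>As a rough illustration of what the strides mean, the sketch below computes row-major strides in plain Python for a tensor of shape (<code  class="highlight"><span></span><span class="n">batch_size</span></code>, <code  class="highlight"><span></span><span class="n">n_heads</span></code>, <code  class="highlight"><span></span><span class="n">q_seq_len</span></code>, <code  class="highlight"><span></span><span class="n">d_head</span></code>). The helpers <code  class="highlight"><span></span><span class="n">contiguous_strides</span></code> and <code  class="highlight"><span></span><span class="n">flat_index</span></code> and the sizes are hypothetical, and the contiguous layout is an assumption rather than something the kernel signature states.</p>
            <p>Under that assumption the last two strides come out as (<code  class="highlight"><span></span><span class="n">d_head</span></code>, 1), the same pair the kernel passes when it builds the query block pointer.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
def contiguous_strides(shape):
    """Row-major strides, in elements, for a tensor of the given shape."""
    strides = []
    step = 1
    # Walk the dimensions from innermost to outermost.
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


# Hypothetical sizes, chosen only for illustration.
batch_size, n_heads, q_seq_len, d_head = 2, 4, 8, 16
stride_z, stride_h, stride_m, stride_d = contiguous_strides(
    (batch_size, n_heads, q_seq_len, d_head))


def flat_index(z, h, m, d):
    """Offset of element (z, h, m, d) in the flattened buffer."""
    return z * stride_z + h * stride_h + m * stride_m + d * stride_d
```
</pre></div>
        </div>
    </div>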
    <div class='section' id='section-35'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-35'>#</a>
            </div>
            <p>We are computing the attention output block <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqci" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>, that is query rows <code  class="highlight"><span></span><span class="n">i</span> <span class="o">*</span> <span class="n">BLOCK_Q</span></code>
 ... <code  class="highlight"><span></span><span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">BLOCK_Q</span></code> (end exclusive), in batch/head combination <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.04398em;">z</span></span></span></span></span>. </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">378</span>    <span class="n">i</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="lineno">379</span>    <span class="n">z</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">n_groups</span>
<span class="lineno">380</span>    <span class="n">g</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="n">n_groups</span></pre></div>
        </div>
    </div>
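    <div class='section'>
        <div class='docs'>
            <p>The index arithmetic above can be checked in plain Python. This is only a sketch: it assumes that axis 1 of the launch grid enumerates every (batch/head, group) pair, which is not shown in this section, and the helper names are hypothetical. The base offsets follow the expressions the kernel uses for its query and key/value block pointers.</p>
            <p>The point it illustrates is that all <code  class="highlight"><span></span><span class="n">n_groups</span></code> programs sharing the same <code  class="highlight"><span></span><span class="n">z</span></code> read the same keys and values, while each gets its own slice of the query.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
def split_program_id(pid, n_groups):
    """Mirror the kernel's integer division and remainder on grid axis 1."""
    z, g = divmod(pid, n_groups)
    return z, g


def q_base_offset(pid, n_groups, q_seq_len, d_head):
    """Element offset of the query slice handled by program `pid`."""
    z, g = split_program_id(pid, n_groups)
    return z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head


def kv_base_offset(pid, n_groups, kv_seq_len, d_head):
    """Element offset of the key/value slice; it depends on z only."""
    z, _ = split_program_id(pid, n_groups)
    return z * kv_seq_len * d_head


# Hypothetical sizes, chosen only for illustration.
n_groups, q_seq_len, kv_seq_len, d_head = 4, 8, 8, 16
shared_kv = {kv_base_offset(pid, n_groups, kv_seq_len, d_head)
             for pid in range(4, 8)}
```
</pre></div>
        </div>
    </div>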
    <div class='section' id='section-36'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-36'>#</a>
            </div>
            <h4>Create block pointers</h4>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">383</span>    <span class="n">p_q</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_q</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">n_groups</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">*</span> <span class="n">d_head</span> <span class="o">+</span> <span class="n">g</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">384</span>                            <span class="p">(</span><span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">385</span>                            <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="lineno">386</span>                            <span class="p">(</span><span class="n">i</span> <span class="o">*</span> <span class="n">BLOCK_Q</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="lineno">387</span>                            <span class="p">(</span><span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">388</span>                            <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="lineno">389</span>    <span class="n">p_v</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_v</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">kv_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">390</span>                            <span class="p">(</span><span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">391</span>                            <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="lineno">392</span>                            <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="lineno">393</span>                            <span class="p">(</span><span class="n">BLOCK_K</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">394</span>                            <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="lineno">395</span>    <span class="n">p_kT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_k</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">kv_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">396</span>                             <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="n">kv_seq_len</span><span class="p">),</span>
<span class="lineno">397</span>                             <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">398</span>                             <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="lineno">399</span>                             <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">),</span>
<span class="lineno">400</span>                             <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="lineno">401</span>    <span class="n">p_o</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_o</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">n_groups</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">*</span> <span class="n">d_head</span> <span class="o">+</span> <span class="n">g</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">402</span>                            <span class="p">(</span><span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">403</span>                            <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="lineno">404</span>                            <span class="p">(</span><span class="n">i</span> <span class="o">*</span> <span class="n">BLOCK_Q</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="lineno">405</span>                            <span class="p">(</span><span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">406</span>                            <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="lineno">407</span>    <span class="n">p_lse</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_lse</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">n_groups</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">+</span> <span class="n">g</span> <span class="o">*</span> <span class="n">q_seq_len</span><span class="p">,</span>
<span class="lineno">408</span>                              <span class="p">(</span><span class="n">q_seq_len</span><span class="p">,),</span>
<span class="lineno">409</span>                              <span class="p">(</span><span class="mi">1</span><span class="p">,),</span>
<span class="lineno">410</span>                              <span class="p">(</span><span class="n">i</span> <span class="o">*</span> <span class="n">BLOCK_Q</span><span class="p">,),</span>
<span class="lineno">411</span>                              <span class="p">(</span><span class="n">BLOCK_Q</span><span class="p">,),</span>
<span class="lineno">412</span>                              <span class="p">(</span><span class="mi">0</span><span class="p">,))</span></pre></div>
        </div>
    </div>
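    <div class='section'>
        <div class='docs'>
            <p>The key pointer reads the transposed key without copying any data: it declares the shape as (<code  class="highlight"><span></span><span class="n">d_head</span></code>, <code  class="highlight"><span></span><span class="n">kv_seq_len</span></code>) and swaps the strides to (1, <code  class="highlight"><span></span><span class="n">d_head</span></code>). The sketch below imitates that addressing on a flat Python list. <code  class="highlight"><span></span><span class="n">read_block</span></code> is a hypothetical helper, not Triton's block pointer; it ignores the order argument and does no boundary checking.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
def read_block(flat, base, strides, offsets, block_shape):
    """Gather a 2-D block from a flat buffer using per-dimension strides."""
    rows, cols = block_shape
    return [[flat[base
                  + (offsets[0] + r) * strides[0]
                  + (offsets[1] + c) * strides[1]]
             for c in range(cols)]
            for r in range(rows)]


# Hypothetical sizes; k[n][d] is stored row-major as (kv_seq_len, d_head).
kv_seq_len, d_head, BLOCK_K = 4, 3, 2
k = [[10 * n + d for d in range(d_head)] for n in range(kv_seq_len)]
flat_k = [x for row in k for x in row]

# Same strides, offsets and block shape as the key pointer in the kernel.
b_kT = read_block(flat_k, 0, (1, d_head), (0, 0), (d_head, BLOCK_K))
```
</pre></div>
        </div>
    </div>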
    <div class='section' id='section-37'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-37'>#</a>
            </div>
            <p>Initialize offsets </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">415</span>    <span class="n">offs_i</span> <span class="o">=</span> <span class="n">i</span> <span class="o">*</span> <span class="n">BLOCK_Q</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_Q</span><span class="p">)</span>
<span class="lineno">416</span>    <span class="n">offs_j</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-38'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-38'>#</a>
            </div>
            <p>Mask for the rows of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8777699999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcw" style=""><span class="mord mathnormal" style="">Q</span></span></span></span></span></span>; only the last block can extend past <code  class="highlight"><span></span><span class="n">q_seq_len</span></code> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">419</span>    <span class="n">i_mask</span> <span class="o">=</span> <span class="n">offs_i</span> <span class="o">&lt;</span> <span class="n">q_seq_len</span></pre></div>
        </div>
    </div>
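    <div class='section'>
        <div class='docs'>
            <p>A small plain-Python sketch of the offsets and the row mask, for a sequence length that is not a multiple of the block size. The function names and sizes are hypothetical, and the number of query blocks is assumed to be the ceiling of <code  class="highlight"><span></span><span class="n">q_seq_len</span> <span class="o">/</span> <span class="n">BLOCK_Q</span></code>, since the launch grid is not shown here.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
def cdiv(a, b):
    """Ceiling division for positive integers."""
    return -(-a // b)


def query_offsets_and_mask(i, BLOCK_Q, q_seq_len):
    """Row indices handled by query block i, and which of them exist."""
    offs_i = [i * BLOCK_Q + r for r in range(BLOCK_Q)]
    # True where the row index is below q_seq_len, as in the kernel's mask.
    i_mask = [q_seq_len > o for o in offs_i]
    return offs_i, i_mask


# Hypothetical sizes: 10 query rows in blocks of 4 gives 3 blocks.
q_seq_len, BLOCK_Q = 10, 4
n_blocks = cdiv(q_seq_len, BLOCK_Q)
first_offs, first_mask = query_offsets_and_mask(0, BLOCK_Q, q_seq_len)
last_offs, last_mask = query_offsets_and_mask(n_blocks - 1, BLOCK_Q, q_seq_len)
```
</pre></div>
        </div>
    </div>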
    <div class='section' id='section-39'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-39'>#</a>
            </div>
            <p>Initialize <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcm" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcl" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.01968em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>. 
<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcm" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> is initialized to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.77777em;vertical-align:-0.08333em;"></span><span class="mord">−</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">in<span style="margin-right:0.07778em;">f</span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcl" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.01968em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord 
mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqcs" style=""><span class="mord" style="">1</span></span></span></span></span></span>. So in the first update, the effect of initial <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcl" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.01968em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.02998em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqz" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.87998em;"><span 
style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcm" style=""><span class="mord mathnormal mtight" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span><span class="mbin mtight" style="">−</span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7385428571428572em;"><span style="top:-2.214em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord text mtight" style=""><span class="mord mtight" style="">n</span><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcy" style="">e</span></span><span class="mord mtight" 
style="">w</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="mord coloredeq eqcl" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.01968em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqcr" style=""><span class="mord" style="">0</span></span></span></span></span></span>.</p>
<p><code  class="highlight"><span></span><span class="n">b_m</span></code>
 stores <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.93858em;vertical-align:-0.24414em;"></span><span class="mord coloredeq eqcm" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqbr" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">425</span>    <span class="n">b_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">i_mask</span><span class="p">,</span> <span class="o">-</span><span class="nb">float</span><span class="p">(</span><span class="s2">&quot;inf&quot;</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">)</span>
<span class="lineno">426</span>    <span class="n">b_l</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">i_mask</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">)</span></pre></div>
        </div>
    </div>
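    <div class='section'>
        <div class='docs'>
            <p>To see why the initial value of the normaliser does not matter, here is a plain-Python sketch of a blockwise (online) softmax normaliser for a single query row. It is the textbook recurrence written with base-2 exponentials, not the kernel's inner loop: the function name is hypothetical, the softmax scale is left out, and the running maximum is kept in base-2 units as described above.</p>
            <p>On the first block the rescaling factor is 2 raised to minus infinity, which is exactly zero, so the starting value 1 is discarded.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
import math

LOG2_E = math.log2(math.e)


def online_softmax_normaliser(scores, block):
    """Return (m, l): running max in base-2 units and the softmax normaliser."""
    m = -math.inf  # running maximum, starts at minus infinity
    l = 1.0        # running normaliser, the starting value is discarded
    for start in range(0, len(scores), block):
        # Convert this block of scores to base-2 units.
        s = [x * LOG2_E for x in scores[start:start + block]]
        m_new = max(m, max(s))
        # Rescale the old sum to the new maximum, then add this block.
        l = 2.0 ** (m - m_new) * l + sum(2.0 ** (x - m_new) for x in s)
        m = m_new
    return m, l


# Hypothetical scores, processed two at a time.
scores = [0.5, -1.0, 2.0, 3.5, 0.0]
m, l = online_softmax_normaliser(scores, block=2)
```
</pre></div>
        </div>
    </div>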
    <div class='section' id='section-40'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-40'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqci" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">429</span>    <span class="n">b_o</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">d_head</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">HI_PRES_TL</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-41'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-41'>#</a>
            </div>
            <p>Load <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8777699999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> outside the loop since it is reused throughout the loop over <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span>. </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">432</span>    <span class="n">b_q</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">p_q</span><span class="p">,</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,),</span> <span class="n">padding_option</span><span class="o">=</span><span class="s2">&quot;zero&quot;</span><span class="p">)</span>
<span class="lineno">433</span>
<span class="lineno">434</span>    <span class="k">if</span> <span class="n">is_causal</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-42'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-42'>#</a>
            </div>
            <p>Inner loop up to, but not including, the diagonal block; no masking is needed here </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">436</span>        <span class="n">b_o</span><span class="p">,</span> <span class="n">b_l</span><span class="p">,</span> <span class="n">b_m</span> <span class="o">=</span> <span class="n">_attn_fwd_inner</span><span class="p">(</span><span class="n">b_o</span><span class="p">,</span> <span class="n">b_l</span><span class="p">,</span> <span class="n">b_m</span><span class="p">,</span> <span class="n">b_q</span><span class="p">,</span>
<span class="lineno">437</span>                                        <span class="n">p_kT</span><span class="p">,</span> <span class="n">p_v</span><span class="p">,</span>
<span class="lineno">438</span>                                        <span class="n">sm_scale_log2e</span><span class="p">,</span>
<span class="lineno">439</span>                                        <span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">d_head</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">,</span>
<span class="lineno">440</span>                                        <span class="n">offs_i</span><span class="p">,</span> <span class="n">offs_j</span><span class="p">,</span>
<span class="lineno">441</span>                                        <span class="n">j</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">full</span><span class="p">([],</span> <span class="mi">0</span><span class="p">,</span> <span class="n">tl</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>  <span class="c1"># type: ignore</span>
<span class="lineno">442</span>                                        <span class="n">steps</span><span class="o">=</span><span class="p">(</span><span class="n">i</span> <span class="o">*</span> <span class="n">BLOCK_Q</span><span class="p">)</span> <span class="o">//</span> <span class="n">BLOCK_K</span><span class="p">,</span>
<span class="lineno">443</span>                                        <span class="n">MASK</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="lineno">444</span>                                        <span class="n">q_seq_len</span><span class="o">=</span><span class="n">q_seq_len</span><span class="p">,</span>
<span class="lineno">445</span>                                        <span class="n">kv_seq_len</span><span class="o">=</span><span class="n">kv_seq_len</span>
<span class="lineno">446</span>                                        <span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-43'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-43'>#</a>
            </div>
            <p>Diagonal block with masking within it </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">448</span>        <span class="n">b_o</span><span class="p">,</span> <span class="n">b_l</span><span class="p">,</span> <span class="n">b_m</span> <span class="o">=</span> <span class="n">_attn_fwd_inner</span><span class="p">(</span><span class="n">b_o</span><span class="p">,</span> <span class="n">b_l</span><span class="p">,</span> <span class="n">b_m</span><span class="p">,</span> <span class="n">b_q</span><span class="p">,</span> <span class="n">p_kT</span><span class="p">,</span> <span class="n">p_v</span><span class="p">,</span>
<span class="lineno">449</span>                                        <span class="n">sm_scale_log2e</span><span class="p">,</span>
<span class="lineno">450</span>                                        <span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">d_head</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">,</span>
<span class="lineno">451</span>                                        <span class="n">offs_i</span><span class="p">,</span> <span class="n">offs_j</span><span class="p">,</span>
<span class="lineno">452</span>                                        <span class="n">j</span><span class="o">=</span><span class="n">i</span> <span class="o">*</span> <span class="n">BLOCK_Q</span><span class="p">,</span>
<span class="lineno">453</span>                                        <span class="n">steps</span><span class="o">=</span><span class="n">BLOCK_Q</span> <span class="o">//</span> <span class="n">BLOCK_K</span><span class="p">,</span>
<span class="lineno">454</span>                                        <span class="n">MASK</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">455</span>                                        <span class="n">q_seq_len</span><span class="o">=</span><span class="n">q_seq_len</span><span class="p">,</span>
<span class="lineno">456</span>                                        <span class="n">kv_seq_len</span><span class="o">=</span><span class="n">kv_seq_len</span>
<span class="lineno">457</span>                                        <span class="p">)</span>
<span class="lineno">458</span>    <span class="k">else</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-44'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-44'>#</a>
            </div>
            <p>Iterate through all <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">460</span>        <span class="n">b_o</span><span class="p">,</span> <span class="n">b_l</span><span class="p">,</span> <span class="n">b_m</span> <span class="o">=</span> <span class="n">_attn_fwd_inner</span><span class="p">(</span><span class="n">b_o</span><span class="p">,</span> <span class="n">b_l</span><span class="p">,</span> <span class="n">b_m</span><span class="p">,</span> <span class="n">b_q</span><span class="p">,</span> <span class="n">p_kT</span><span class="p">,</span> <span class="n">p_v</span><span class="p">,</span>
<span class="lineno">461</span>                                        <span class="n">sm_scale_log2e</span><span class="p">,</span>
<span class="lineno">462</span>                                        <span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">d_head</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">,</span>
<span class="lineno">463</span>                                        <span class="n">offs_i</span><span class="p">,</span> <span class="n">offs_j</span><span class="p">,</span>
<span class="lineno">464</span>                                        <span class="n">j</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">full</span><span class="p">([],</span> <span class="mi">0</span><span class="p">,</span> <span class="n">tl</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>  <span class="c1"># type: ignore</span>
<span class="lineno">465</span>                                        <span class="n">steps</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">),</span>
<span class="lineno">466</span>                                        <span class="n">MASK</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="lineno">467</span>                                        <span class="n">q_seq_len</span><span class="o">=</span><span class="n">q_seq_len</span><span class="p">,</span>
<span class="lineno">468</span>                                        <span class="n">kv_seq_len</span><span class="o">=</span><span class="n">kv_seq_len</span>
<span class="lineno">469</span>                                        <span class="p">)</span></pre></div>
        </div>
    </div>
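    <div class='section'>
        <div class='docs'>
            <p>The causal and non-causal branches above differ only in where the inner loop starts and how many steps it runs. The plain-Python sketch below mirrors that arithmetic so the ranges can be inspected on the CPU. <code>kv_ranges</code> is a hypothetical helper, not part of the kernel, and it assumes <code>BLOCK_Q</code> is a multiple of <code>BLOCK_K</code>, as the inner loop asserts.</p>
        </div>
        <div class='code'>
```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same arithmetic as tl.cdiv."""
    return (a + b - 1) // b


def kv_ranges(i, block_q, block_k, kv_seq_len, is_causal):
    """Return (start, steps, masked) triples for query block `i`.

    Each triple mirrors one `_attn_fwd_inner` call: it starts at key
    position `start` and runs `steps` iterations of `block_k` keys.
    """
    assert block_q % block_k == 0
    if is_causal:
        return [
            # Key blocks strictly before the diagonal need no mask
            (0, (i * block_q) // block_k, False),
            # The diagonal block is masked element-wise
            (i * block_q, block_q // block_k, True),
        ]
    # Non-causal: visit every key block, rounding the count up
    return [(0, cdiv(kv_seq_len, block_k), False)]
```
        </div>
    </div>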
    <div class='section' id='section-45'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-45'>#</a>
            </div>
            <p>Store LSE <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.93858em;vertical-align:-0.24414em;"></span><span class="mord coloredeq eqbp" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqch" style=""><span class="mord mathnormal" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" 
style="height:1.20001em;vertical-align:-0.35001em;"></span><span class="mop"><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqct" style=""><span class="mord mtight" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="delimsizing size1">(</span></span><span class="mord coloredeq eqcl" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.01968em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">∗</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.20001em;vertical-align:-0.35001em;"></span><span class="mord"><span class="mord coloredeq eqcy" style=""><span class="mord mathnormal" style="">e</span></span><span 
class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.664392em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcm" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="mord"><span class="delimsizing size1">)</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.93858em;vertical-align:-0.24414em;"></span><span class="mop"><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqct" style=""><span class="mord mtight" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqcl" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.01968em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.93858em;vertical-align:-0.24414em;"></span><span class="mord coloredeq eqcm" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop"><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqct" style=""><span class="mord mtight" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqcy" style=""><span class="mord mathnormal" style="">e</span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">472</span>    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">p_lse</span><span class="p">,</span> <span class="n">b_m</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">math</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">b_l</span><span class="p">),</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,))</span></pre></div>
        </div>
    </div>
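    <div class='section'>
        <div class='docs'>
            <p>A quick numeric check of the stored quantity, using only Python's <code>math</code> module. It assumes, as the use of <code>sm_scale_log2e</code> suggests, that the running maximum is kept in base-2 units, so that <code>b_m + log2(b_l)</code> equals the base-2 logarithm of the softmax denominator. The function names are illustrative and do not appear in the kernel.</p>
        </div>
        <div class='code'>
```python
import math

LOG2_E = math.log2(math.e)


def lse2_direct(scores):
    """log2 of sum(exp(s)), computed naively."""
    return math.log2(sum(math.exp(s) for s in scores))


def lse2_online(scores):
    """Same quantity from a base-2 running max and a rescaled sum."""
    scaled = [s * LOG2_E for s in scores]      # scores in base-2 units
    m2 = max(scaled)                           # plays the role of b_m
    l = sum(2.0 ** (s - m2) for s in scaled)   # plays the role of b_l
    return m2 + math.log2(l)                   # mirrors b_m + log2(b_l)
```
        </div>
    </div>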
    <div class='section' id='section-46'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-46'>#</a>
            </div>
            <p>Store <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.4993329999999996em;vertical-align:-0.44509999999999994em;"></span><span class="mord coloredeq eqy" style=""><span class="mord" style=""><span class="mord coloredeq eqci" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0542329999999998em;"><span style="top:-2.655em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcl" style=""><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.01968em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing 
reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.4101em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbo" style=""><span class="mord accent mtight" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9201899999999998em;"><span style="top:-2.7em;"><span class="pstrut" style="height:2.7em;"></span><span class="mord mathnormal mtight" style="margin-right:0.02778em">O</span></span><span style="top:-3.3023300000000004em;"><span class="pstrut" style="height:2.7em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord mtight" style="">~</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.02778em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.44509999999999994em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">474</span>    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">p_o</span><span class="p">,</span> <span class="p">(</span><span class="n">b_o</span> <span class="o">/</span> <span class="n">b_l</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">])</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">t_o</span><span class="o">.</span><span class="n">type</span><span class="o">.</span><span class="n">element_ty</span><span class="p">),</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,))</span></pre></div>
        </div>
    </div>
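    <div class='section'>
        <div class='docs'>
            <p>The single division at the end works because the accumulator holds an unnormalized sum that is rescaled whenever the running maximum changes. Below is a minimal pure-Python sketch of that idea for one query row with scalar values, visiting keys block by block. <code>online_attention_row</code> and <code>softmax_attention_row</code> are hypothetical names, and the sketch uses natural exponentials rather than the kernel's base-2 arithmetic.</p>
        </div>
        <div class='code'>
```python
import math


def online_attention_row(scores, values, block_k):
    """Attention output for one query row, visiting keys block by block."""
    m = float("-inf")   # running maximum of the scores
    l = 0.0             # running sum of exp(score - m)
    o = 0.0             # unnormalized output accumulator
    for start in range(0, len(scores), block_k):
        s_blk = scores[start:start + block_k]
        v_blk = values[start:start + block_k]
        m_new = max(m, max(s_blk))
        # Rescale what was accumulated under the old maximum
        alpha = math.exp(m - m_new)
        p = [math.exp(s - m_new) for s in s_blk]
        l = l * alpha + sum(p)
        o = o * alpha + sum(pi * vi for pi, vi in zip(p, v_blk))
        m = m_new
    # Normalize once at the end, as in the final store
    return o / l


def softmax_attention_row(scores, values):
    """Reference: ordinary softmax followed by a weighted sum."""
    z = sum(math.exp(s) for s in scores)
    return sum(math.exp(s) / z * v for s, v in zip(scores, values))
```
        </div>
    </div>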
    <div class='section' id='section-47'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-47'>#</a>
            </div>
            <h4>Inner loop to calculate <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqci" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></h4>
<p>This iterates through keys and values starting from <code  class="highlight"><span></span><span class="n">j</span></code>
 for <code  class="highlight"><span></span><span class="n">steps</span></code>
 number of steps. In each step it processes <code  class="highlight"><span></span><span class="n">BLOCK_K</span></code>
 entries of keys/values.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">477</span><span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="lineno">478</span><span class="k">def</span> <span class="nf">_attn_fwd_inner</span><span class="p">(</span><span class="n">b_o</span><span class="p">,</span> <span class="n">b_l</span><span class="p">,</span> <span class="n">b_m</span><span class="p">,</span> <span class="n">b_q</span><span class="p">,</span>
<span class="lineno">479</span>                    <span class="n">p_kT</span><span class="p">,</span> <span class="n">p_v</span><span class="p">,</span>
<span class="lineno">480</span>                    <span class="n">sm_scale_log2e</span><span class="p">,</span>
<span class="lineno">481</span>                    <span class="n">BLOCK_Q</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">482</span>                    <span class="n">d_head</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">483</span>                    <span class="n">BLOCK_K</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">484</span>                    <span class="n">offs_i</span><span class="p">,</span> <span class="n">offs_j</span><span class="p">,</span>
<span class="lineno">485</span>                    <span class="n">j</span><span class="p">,</span>
<span class="lineno">486</span>                    <span class="n">steps</span><span class="p">,</span>
<span class="lineno">487</span>                    <span class="n">MASK</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">488</span>                    <span class="n">q_seq_len</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">489</span>                    <span class="n">kv_seq_len</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span>
<span class="lineno">490</span>                    <span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-48'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-48'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">497</span>    <span class="n">tl</span><span class="o">.</span><span class="n">static_assert</span><span class="p">(</span><span class="n">BLOCK_Q</span> <span class="o">%</span> <span class="n">BLOCK_K</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-49'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-49'>#</a>
            </div>
            <p>Move <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqck" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> pointers </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">500</span>    <span class="n">p_kT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">advance</span><span class="p">(</span><span class="n">p_kT</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">j</span><span class="p">))</span>
<span class="lineno">501</span>    <span class="n">p_v</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">advance</span><span class="p">(</span><span class="n">p_v</span><span class="p">,</span> <span class="p">(</span><span class="n">j</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-50'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-50'>#</a>
            </div>
            <p>Iterate over <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcv" style=""><span class="mord mathnormal" style="margin-right:0.07153em">K</span></span></span></span></span></span>, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcx" style=""><span class="mord mathnormal" style="margin-right:0.22222em">V</span></span></span></span></span></span> and update <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.0701899999999998em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbo" style=""><span class="mord" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9201899999999998em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.02778em">O</span></span><span style="top:-3.6023300000000003em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord" style="">~</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcl" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.01968em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">504</span>    <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">steps</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-51'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-51'>#</a>
            </div>
            <p>Load <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.200669em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqbx" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.914561em;"><span style="top:-3.1362300000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.13889em">T</span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">506</span>        <span class="n">b_kT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">p_kT</span><span class="p">,</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">padding_option</span><span class="o">=</span><span class="s2">&quot;zero&quot;</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-52'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-52'>#</a>
            </div>
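            <p>Compute <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord coloredeq eqbr" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span><span class="mclose">)</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05764em;">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord coloredeq eqda" style=""><span class="mord mathnormal" style="margin-right:0.05724em">j</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.200669em;vertical-align:-0.286108em;"></span><span class="mopen">(</span><span class="mord coloredeq eqbr" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span><span class="mclose">)</span><span class="mord coloredeq eqbv" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span></span><span class="mord coloredeq eqcj" 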
style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqbx" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.914561em;"><span style="top:-3.1362300000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" 
style="margin-right:0.13889em">T</span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">508</span>        <span class="n">b_s</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">b_q</span><span class="p">,</span> <span class="n">b_kT</span><span class="p">,</span> <span class="n">out_dtype</span><span class="o">=</span><span class="n">HI_PRES_TL</span><span class="p">)</span>
<span class="lineno">509</span>        <span class="n">b_s</span> <span class="o">=</span> <span class="n">b_s</span> <span class="o">*</span> <span class="n">sm_scale_log2e</span></pre></div>
        </div>
    </div>
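    <div class='section'>
        <div class='docs'>
            <p>Folding <code>log2(e)</code> into the scale is what later allows <code>exp2</code> to stand in for <code>exp</code>. The following is a minimal pure-Python sketch of that identity, assuming <code>sm_scale_log2e</code> is the softmax scale multiplied by <code>log2(e)</code>. The helper name <code>to_log2_domain</code> and the sample values are hypothetical and not part of the kernel. </p>
        </div>
        <div class='code'>

```python
import math


def to_log2_domain(score: float, sm_scale: float) -> float:
    """Scale a raw dot product so that 2 ** result == exp(sm_scale * score)."""
    # Hypothetical stand-in for the kernel's `sm_scale_log2e` constant
    sm_scale_log2e = sm_scale * math.log2(math.e)
    return score * sm_scale_log2e


raw_score = 3.7                  # made-up value of one q . k dot product
sm_scale = 1.0 / math.sqrt(64)   # assumes the usual 1 / sqrt(d_head) scale

# Both routes should give the same unnormalized softmax weight
via_exp = math.exp(sm_scale * raw_score)
via_exp2 = 2.0 ** to_log2_domain(raw_score, sm_scale)
```

        </div>
    </div>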
    <div class='section' id='section-53'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-53'>#</a>
            </div>
            <p>Apply the causal mask, so that each query attends only to keys at the same or earlier positions </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">512</span>        <span class="k">if</span> <span class="n">MASK</span><span class="p">:</span>
<span class="lineno">513</span>            <span class="n">causal_mask</span> <span class="o">=</span> <span class="n">offs_i</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">&gt;=</span> <span class="p">(</span><span class="n">j</span> <span class="o">+</span> <span class="n">offs_j</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:])</span>
<span class="lineno">514</span>            <span class="n">b_s</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">causal_mask</span><span class="p">,</span> <span class="n">b_s</span><span class="p">,</span> <span class="o">-</span><span class="nb">float</span><span class="p">(</span><span class="s2">&quot;inf&quot;</span><span class="p">))</span></pre></div>
        </div>
    </div>
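    <div class='section'>
        <div class='docs'>
            <p>The comparison above can be reproduced on plain lists to see which entries survive. This is only an illustrative sketch: it assumes <code>offs_i</code> holds absolute query positions and <code>offs_j</code> holds offsets inside the key block that starts at <code>j</code>, and the helper name <code>causal_block_mask</code> is hypothetical. </p>
        </div>
        <div class='code'>

```python
def causal_block_mask(offs_i: list[int], j: int, offs_j: list[int]) -> list[list[bool]]:
    """True where the query row may attend to the key column."""
    # Assumption: offs_i are absolute query positions, offs_j are offsets
    # within the key block whose first key sits at absolute position j.
    return [[i >= j + dj for dj in offs_j] for i in offs_i]


# Diagonal block: queries 4..7 against keys 4..7 gives a lower-triangular mask
mask = causal_block_mask(offs_i=[4, 5, 6, 7], j=4, offs_j=[0, 1, 2, 3])
```

        </div>
    </div>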
    <div class='section' id='section-54'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-54'>#</a>
            </div>
            <p>Mask out the key positions of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> that lie beyond the end of the key sequence </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">517</span>        <span class="n">j_mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">j</span> <span class="o">+</span> <span class="n">offs_j</span><span class="p">)</span> <span class="o">&lt;</span> <span class="n">kv_seq_len</span>
<span class="lineno">518</span>        <span class="n">b_s</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">j_mask</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:],</span> <span class="n">b_s</span><span class="p">,</span> <span class="o">-</span><span class="nb">float</span><span class="p">(</span><span class="s2">&quot;inf&quot;</span><span class="p">))</span></pre></div>
        </div>
    </div>
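    <div class='section'>
        <div class='docs'>
            <p>Because the key block is loaded with <code>padding_option=&quot;zero&quot;</code>, an out-of-range key column produces a score of 0 rather than negative infinity, and would still receive softmax weight unless it is masked. A small sketch of the bounds mask on plain lists follows; the sizes, the scores and the helper name <code>kv_bounds_mask</code> are made up for illustration. </p>
        </div>
        <div class='code'>

```python
def kv_bounds_mask(j: int, offs_j: list[int], kv_seq_len: int) -> list[bool]:
    """True for key columns that exist, False for padded columns."""
    return [kv_seq_len > j + dj for dj in offs_j]


# Hypothetical sizes: 10 keys in total, the last block of width 4 starts at j = 8
j_mask = kv_bounds_mask(j=8, offs_j=[0, 1, 2, 3], kv_seq_len=10)

# Zero-padded keys score 0, not -inf, so they are overwritten explicitly
scores = [0.5, -1.0, 0.0, 0.0]
masked = [s if keep else float("-inf") for s, keep in zip(scores, j_mask)]
```

        </div>
    </div>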
    <div class='section' id='section-55'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-55'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.008664em;vertical-align:-0.258664em;"></span><span class="mopen">(</span><span class="mord coloredeq eqbr" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span><span class="mclose">)</span><span class="mord"><span class="mord mathnormal">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.664392em;"><span style="top:-2.441336em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">n</span><span class="mord mtight coloredeq eqcy" style=""><span class="mord mtight" 
style="">e</span></span><span class="mord mtight">w</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.258664em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.355544em;vertical-align:-0.412972em;"></span><span class="mop">max</span><span class="mopen">((</span><span class="mord coloredeq eqbr" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span><span class="mclose">)</span><span class="mord coloredeq eqcm" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span 
class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop"><span class="mop">max</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.942572em;"><span style="top:-2.4231360000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqda" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span><span class="mrel mtight">=</span><span class="mord mtight coloredeq eqda" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span><span class="mord mtight coloredeq eqcs" style=""><span class="mord mtight" style="">1</span></span></span></span></span><span style="top:-3.1809080000000005em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqda" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span><span class="mord mtight coloredeq eqct" style=""><span class="mord mtight" style="">2</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.412972em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord coloredeq eqbr" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span 
class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span><span class="mclose">)</span><span class="mord coloredeq eqbu" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">521</span>        <span class="n">b_m_new</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">b_m</span><span class="p">,</span> <span class="n">tl</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">b_s</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span></pre></div>
        </div>
    </div>
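    <div class='section'>
        <div class='docs'>
            <p>The running maximum is carried from one key block to the next, so after the last block it equals the maximum over the whole row. A sketch for a single query row, assuming the maximum starts at negative infinity before the first block; the scores and the helper name <code>update_running_max</code> are hypothetical. </p>
        </div>
        <div class='code'>

```python
import math


def update_running_max(m: float, block_scores: list[float]) -> float:
    """Return max(m, max(block_scores)) for one query row."""
    return max(m, max(block_scores))


m = -math.inf  # assumed initial value before any key block is seen
blocks = [[1.0, 3.0], [2.0, -math.inf], [5.0, 4.0]]  # made-up scores, one row

history = []
for block in blocks:
    m = update_running_max(m, block)
    history.append(m)  # the maximum never decreases across blocks
```

        </div>
    </div>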
    <div class='section' id='section-56'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-56'>#</a>
            </div>
            <span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:3.1959999999999997em;vertical-align:-1.3479999999999999em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.8479999999999999em;"><span style="top:-3.91em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord accent"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9201899999999998em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">P</span></span><span style="top:-3.6023300000000003em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord">~</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span><span class="mord mtight coloredeq eqda" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.3120000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="mord"></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:1.3479999999999999em;"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.8479999999999999em;"><span style="top:-3.91em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord"><span class="mord coloredeq eqcy" style=""><span class="mord mathnormal" style="">e</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9379999999999998em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight coloredeq eqbu" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.05764em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span><span class="mbin mtight">−</span><span class="mord mtight"><span class="mord mtight 
coloredeq eqcm" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7385428571428572em;"><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">n</span><span class="mord mtight coloredeq eqcy" style=""><span class="mord mtight" style="">e</span></span><span class="mord mtight">w</span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-2.3120000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord"><span class="mord coloredeq eqct" style=""><span class="mord" style="">2</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9379999999999998em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span 
class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight coloredeq eqbr" style=""><span class="mop mtight" style=""><span class="mop mtight" style=""><span class="mtight" style="">l</span><span class="mtight" style="">o</span><span class="mtight" style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.19444571428571428em;"><span style="top:-2.2341314285714287em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.26586857142857145em;"><span></span></span></span></span></span></span><span class="mspace mtight" style="margin-right:0.19516666666666668em;"></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcy" style="">e</span></span></span><span class="mclose mtight">)</span><span class="mord mtight coloredeq eqbu" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.05764em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span><span class="mbin mtight">−</span><span class="mopen mtight">(</span><span class="mord mtight coloredeq eqbr" style=""><span class="mop mtight" style=""><span class="mop mtight" style=""><span class="mtight" style="">l</span><span class="mtight" style="">o</span><span class="mtight" style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.19444571428571428em;"><span style="top:-2.2341314285714287em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.26586857142857145em;"><span></span></span></span></span></span></span><span class="mspace mtight" style="margin-right:0.19516666666666668em;"></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcy" style="">e</span></span></span><span class="mclose mtight">)</span><span class="mord mtight"><span class="mord mtight coloredeq eqcm" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span 
class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7385428571428572em;"><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">n</span><span class="mord mtight coloredeq eqcy" style=""><span class="mord mtight" style="">e</span></span><span class="mord mtight">w</span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.3479999999999999em;"><span></span></span></span></span></span></span></span></span></span></span></span></span><p> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">527</span>        <span class="n">b_p</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">math</span><span class="o">.</span><span class="n">exp2</span><span class="p">(</span><span class="n">b_s</span> <span class="o">-</span> <span class="n">b_m_new</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">])</span></pre></div>
        </div>
    </div>
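    <div class='section'>
        <div class='docs'>
            <p>As a quick check of the base-2 rewrite above, here is a minimal NumPy sketch. It is not part of the kernel, and the array names and shapes are made up for illustration. It compares exp(S - m) with exp2((log2 e) S - (log2 e) m), which is the form the kernel evaluates once the scores have been pre-scaled by log2 e.</p>

```python
import numpy as np

LOG2_E = np.log2(np.e)  # log_2 e, roughly 1.4427

rng = np.random.default_rng(0)
s = rng.normal(size=(4, 8))          # illustrative raw scores S_ij
m = s.max(axis=-1, keepdims=True)    # row maximum m_i

# Reference: natural exponential of the shifted scores
p_ref = np.exp(s - m)

# Pre-scale scores and maximum by log_2 e, then use the base-2 exponential
p_exp2 = np.exp2(LOG2_E * s - LOG2_E * m)

print(np.allclose(p_ref, p_exp2))
```

        </div>
    </div>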
    <div class='section' id='section-57'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-57'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.400382em;vertical-align:-0.43581800000000004em;"></span><span class="mord coloredeq eqx" style=""><span class="mop" style=""><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.964564em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span><span class="mrel mtight" style="">=</span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcs" style="">1</span></span></span></span></span><span style="top:-3.2029000000000005em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.43581800000000004em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9201899999999998em;"><span 
style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.13889em">P</span></span><span style="top:-3.6023300000000003em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord" style="">~</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">530</span>        <span class="n">b_l_new</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">b_p</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-58'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-58'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.87998em;vertical-align:0em;"></span><span class="mord coloredeq eqz" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.87998em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcm" style=""><span class="mord mathnormal mtight" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span><span class="mbin mtight" style="">−</span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7385428571428572em;"><span style="top:-2.214em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" 
style="">i</span></span></span></span></span><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord text mtight" style=""><span class="mord mtight" style="">n</span><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcy" style="">e</span></span><span class="mord mtight" style="">w</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">532</span>        <span class="n">b_m_m_new</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">math</span><span class="o">.</span><span class="n">exp2</span><span class="p">(</span><span class="n">b_m</span> <span class="o">-</span> <span class="n">b_m_new</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-59'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-59'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcl" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.01968em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.02998em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqz" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.87998em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcm" style=""><span class="mord mathnormal mtight" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span 
style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span><span class="mbin mtight" style="">−</span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7385428571428572em;"><span style="top:-2.214em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord text mtight" style=""><span class="mord mtight" style="">n</span><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcy" style="">e</span></span><span class="mord mtight" style="">w</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="mord coloredeq eqcl" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.01968em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.400382em;vertical-align:-0.43581800000000004em;"></span><span class="mord coloredeq eqx" style=""><span class="mop" style=""><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.964564em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span><span class="mrel mtight" style="">=</span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcs" style="">1</span></span></span></span></span><span style="top:-3.2029000000000005em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal 
mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.43581800000000004em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9201899999999998em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.13889em">P</span></span><span style="top:-3.6023300000000003em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord" style="">~</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">534</span>        <span class="n">b_l</span> <span class="o">=</span> <span class="n">b_l</span> <span class="o">*</span> <span class="n">b_m_m_new</span> <span class="o">+</span> <span class="n">b_l_new</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-60'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-60'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqci" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.02998em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqz" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.87998em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcm" style=""><span class="mord mathnormal mtight" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span 
style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span><span class="mbin mtight" style="">−</span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7385428571428572em;"><span style="top:-2.214em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord text mtight" style=""><span class="mord mtight" style="">n</span><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcy" style="">e</span></span><span class="mord mtight" style="">w</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="mord coloredeq eqci" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.2062979999999999em;vertical-align:-0.286108em;"></span><span class="mord"><span class="mord accent"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9201899999999998em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">P</span></span><span style="top:-3.6023300000000003em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord">~</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span><span class="mord mtight coloredeq eqda" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mord coloredeq eqck" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">537</span>        <span class="n">b_o</span> <span class="o">=</span> <span class="n">b_o</span> <span class="o">*</span> <span class="n">b_m_m_new</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="lineno">538</span>        <span class="n">b_p</span> <span class="o">=</span> <span class="n">b_p</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">b_q</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>  <span class="c1"># TODO</span>
<span class="lineno">539</span>        <span class="n">b_v</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">p_v</span><span class="p">,</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,),</span> <span class="n">padding_option</span><span class="o">=</span><span class="s2">&quot;zero&quot;</span><span class="p">)</span>
<span class="lineno">540</span>        <span class="n">b_o</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">b_p</span><span class="p">,</span> <span class="n">b_v</span><span class="p">,</span> <span class="n">out_dtype</span><span class="o">=</span><span class="n">HI_PRES_TL</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-61'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-61'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord coloredeq eqbr" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span><span class="mclose">)</span><span class="mord coloredeq eqcm" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span 
class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.008664em;vertical-align:-0.258664em;"></span><span class="mopen">(</span><span class="mord coloredeq eqbr" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span><span class="mclose">)</span><span class="mord"><span class="mord mathnormal">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.664392em;"><span style="top:-2.441336em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">n</span><span class="mord mtight coloredeq eqcy" style=""><span 
class="mord mtight" style="">e</span></span><span class="mord mtight">w</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.258664em;"><span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">543</span>        <span class="n">b_m</span> <span class="o">=</span> <span class="n">b_m_new</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-62'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-62'>#</a>
            </div>
            <p>Advance to the next block of keys and values </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">546</span>        <span class="n">j</span> <span class="o">+=</span> <span class="n">BLOCK_K</span>
<span class="lineno">547</span>        <span class="n">p_v</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">advance</span><span class="p">(</span><span class="n">p_v</span><span class="p">,</span> <span class="p">(</span><span class="n">BLOCK_K</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="lineno">548</span>        <span class="n">p_kT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">advance</span><span class="p">(</span><span class="n">p_kT</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">))</span>
<span class="lineno">549</span>
<span class="lineno">550</span>    <span class="n">tl</span><span class="o">.</span><span class="n">static_assert</span><span class="p">(</span><span class="n">b_o</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">HI_PRES_TL</span><span class="p">,</span> <span class="s2">&quot;attn_fwd_inner requires accumulator to be in HI_PRES_TL precision&quot;</span><span class="p">)</span>
<span class="lineno">551</span>
<span class="lineno">552</span>    <span class="k">return</span> <span class="n">b_o</span><span class="p">,</span> <span class="n">b_l</span><span class="p">,</span> <span class="n">b_m</span></pre></div>
        </div>
    </div>
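    <div class='section'>
        <div class='docs'>
            <p>The pointer updates above can be pictured in plain Python. This is a minimal sketch, not the Triton code: the helper name <code  class="highlight"><span></span><span class="n">iter_kv_blocks</span></code> and the list-of-rows layout are assumptions made for illustration, the loop is assumed to start at the first key, and the last block is truncated here whereas a boundary-checked load would pad it.</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
def iter_kv_blocks(k, v, block_k):
    """Yield (j, kT_block, v_block) the way a blocked loop walks keys and values.

    `k` and `v` are lists of rows (kv_seq_len x d_head). K is viewed transposed,
    so a step of `block_k` moves along its columns but along the rows of V.
    Hypothetical helper for illustration only.
    """
    kv_seq_len = len(k)
    d_head = len(k[0])
    # Transposed view of K: d_head rows, kv_seq_len columns.
    k_t = [[k[r][c] for r in range(kv_seq_len)] for c in range(d_head)]
    for j in range(0, kv_seq_len, block_k):
        end = min(j + block_k, kv_seq_len)
        kT_block = [row[j:end] for row in k_t]  # like advancing by (0, BLOCK_K)
        v_block = v[j:end]                      # like advancing by (BLOCK_K, 0)
        yield j, kT_block, v_block


k = [[1, 2], [3, 4], [5, 6]]
v = [[10, 20], [30, 40], [50, 60]]
blocks = list(iter_kv_blocks(k, v, block_k=2))
```
</pre></div>
        </div>
    </div>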
    <div class='section' id='section-63'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-63'>#</a>
            </div>
            <h4>Triton kernel to compute <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcf" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcu" style="margin-right:0.02778em">D</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></h4>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">555</span><span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="lineno">556</span><span class="k">def</span> <span class="nf">_attn_bwd_d</span><span class="p">(</span><span class="n">t_o</span><span class="p">,</span> <span class="n">t_do</span><span class="p">,</span>
<span class="lineno">557</span>                <span class="n">t_pdp</span><span class="p">,</span>
<span class="lineno">558</span>                <span class="n">BLOCK_Q</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">d_head</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">559</span>                <span class="n">q_seq_len</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">560</span>                <span class="n">n_groups</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">561</span>                <span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-64'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-64'>#</a>
            </div>
            <p>Program ID 0 gives the first query row <code  class="highlight"><span></span><span class="n">i</span></code> of this block, and program ID 1 gives <code  class="highlight"><span></span><span class="n">z</span></code>, which selects the slice of <code  class="highlight"><span></span><span class="n">n_groups</span></code> heads this program works on.</p>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">565</span>    <span class="n">i</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">*</span> <span class="n">BLOCK_Q</span>
<span class="lineno">566</span>    <span class="n">z</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-65'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-65'>#</a>
            </div>
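            <p>Create block pointers to the <code  class="highlight"><span></span><span class="n">BLOCK_Q</span></code> rows of <code  class="highlight"><span></span><span class="n">t_o</span></code>, <code  class="highlight"><span></span><span class="n">t_do</span></code> and <code  class="highlight"><span></span><span class="n">t_pdp</span></code> that start at row <code  class="highlight"><span></span><span class="n">i</span></code>, across all <code  class="highlight"><span></span><span class="n">n_groups</span></code> heads of slice <code  class="highlight"><span></span><span class="n">z</span></code></p>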

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">569</span>    <span class="n">p_o</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_o</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">n_groups</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">570</span>                            <span class="p">(</span><span class="n">n_groups</span><span class="p">,</span> <span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">571</span>                            <span class="p">(</span><span class="n">q_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span> <span class="n">d_head</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="lineno">572</span>                            <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="lineno">573</span>                            <span class="p">(</span><span class="n">n_groups</span><span class="p">,</span> <span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">574</span>                            <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="lineno">575</span>    <span class="n">p_do</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_do</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">n_groups</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">576</span>                             <span class="p">(</span><span class="n">n_groups</span><span class="p">,</span> <span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">577</span>                             <span class="p">(</span><span class="n">q_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span> <span class="n">d_head</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="lineno">578</span>                             <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="lineno">579</span>                             <span class="p">(</span><span class="n">n_groups</span><span class="p">,</span> <span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">580</span>                             <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="lineno">581</span>    <span class="n">p_pdp</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_pdp</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">n_groups</span> <span class="o">*</span> <span class="n">q_seq_len</span><span class="p">,</span>
<span class="lineno">582</span>                              <span class="p">(</span><span class="n">n_groups</span><span class="p">,</span> <span class="n">q_seq_len</span><span class="p">),</span>
<span class="lineno">583</span>                              <span class="p">(</span><span class="n">q_seq_len</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="lineno">584</span>                              <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">i</span><span class="p">),</span>
<span class="lineno">585</span>                              <span class="p">(</span><span class="n">n_groups</span><span class="p">,</span> <span class="n">BLOCK_Q</span><span class="p">),</span>
<span class="lineno">586</span>                              <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span></pre></div>
        </div>
    </div>
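    <div class='section'>
        <div class='docs'>
            <p>To make the base offset, shape and strides passed to the block pointers concrete, here is a small sketch that lists the flat element offsets covered by the <code  class="highlight"><span></span><span class="n">p_o</span></code> tile. It assumes a contiguous row-major tensor; the function name <code  class="highlight"><span></span><span class="n">o_tile_offsets</span></code> and the use of <code  class="highlight"><span></span><span class="kc">None</span></code> for out-of-range rows are illustrative choices, not part of the original code.</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
def o_tile_offsets(z, i, n_groups, q_seq_len, d_head, block_q):
    """Flat offsets of the (n_groups, block_q, d_head) tile that starts at row i.

    Assumes a contiguous row-major layout (..., n_groups, q_seq_len, d_head),
    with `z` indexing the flattened leading dimension. Rows past the end of the
    sequence are returned as None, standing in for a masked (padded) load.
    """
    base = z * n_groups * q_seq_len * d_head   # start of slice z
    strides = (q_seq_len * d_head, d_head, 1)  # same strides as the kernel
    tile = []
    for g in range(n_groups):
        rows = []
        for r in range(i, i + block_q):
            if r in range(q_seq_len):
                rows.append([base + g * strides[0] + r * strides[1] + c * strides[2]
                             for c in range(d_head)])
            else:
                rows.append(None)  # outside the sequence axis
        tile.append(rows)
    return tile


tile = o_tile_offsets(z=1, i=2, n_groups=2, q_seq_len=3, d_head=2, block_q=2)
```
</pre></div>
        </div>
    </div>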
    <div class='section' id='section-66'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-66'>#</a>
            </div>
            <p>Load <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqci" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">589</span>    <span class="n">o</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">p_o</span><span class="p">,</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">padding_option</span><span class="o">=</span><span class="s2">&quot;zero&quot;</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-67'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-67'>#</a>
            </div>
            <p>Load <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqci" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">591</span>    <span class="n">do</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">p_do</span><span class="p">,</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">padding_option</span><span class="o">=</span><span class="s2">&quot;zero&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">HI_PRES_TL</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-68'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-68'>#</a>
            </div>
            <p>Calculate <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcf" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcu" style="margin-right:0.02778em">D</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.0645609999999999em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqci" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span 
class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span class="mord"><span class="mord coloredeq eqci" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.914561em;"><span style="top:-3.1362300000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span></span></span></span></span></span></span></span></span> </p>
            <p>Only the dot product of each row with its own gradient is needed, so the code sums the element-wise product of <code  class="highlight"><span></span><span class="n">o</span></code> and <code  class="highlight"><span></span><span class="n">do</span></code> over the last (<code  class="highlight"><span></span><span class="n">d_head</span></code>) axis instead of forming a full matrix product.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">593</span>    <span class="n">d</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">o</span> <span class="o">*</span> <span class="n">do</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-69'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-69'>#</a>
            </div>
            <p>Save <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcf" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcu" style="margin-right:0.02778em">D</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">595</span>    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">p_pdp</span><span class="p">,</span> <span class="n">d</span><span class="p">,</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,))</span></pre></div>
        </div>
    </div>
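    <div class='section'>
        <div class='docs'>
            <p>As a hedged, dependency-free reference for the kernel above, the same quantity can be computed in plain Python one block of query rows at a time. The helper name <code  class="highlight"><span></span><span class="n">rowsum_o_do</span></code> and the list-of-rows layout for a single head are assumptions for illustration. Rows past the end of the sequence are skipped, which has the same effect as zero padding on load followed by a boundary-checked store.</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
def rowsum_o_do(o, do, block_q):
    """Reference for D[i] = sum over k of o[i][k] * do[i][k].

    `o` and `do` are lists of rows (q_seq_len x d_head) for a single head.
    Illustrative sketch only, not the Triton kernel itself.
    """
    q_seq_len = len(o)
    d = [0.0] * q_seq_len
    # Each outer iteration plays the role of one program along grid axis 0.
    for i in range(0, q_seq_len, block_q):
        # Clamp the block end, mirroring the boundary check on the sequence axis.
        end = min(i + block_q, q_seq_len)
        for row in range(i, end):
            d[row] = sum(a * b for a, b in zip(o[row], do[row]))
    return d


o = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
do = [[0.5, 0.5], [1.0, 0.0], [0.0, 2.0]]
d_ref = rowsum_o_do(o, do, block_q=2)
```
</pre></div>
        </div>
    </div>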
    <div class='section' id='section-70'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-70'>#</a>
            </div>
            <h4>Triton kernel to compute <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.980548em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqcb" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.980548em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqce" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqck" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" 
style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></h4>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">598</span><span class="nd">@triton</span><span class="o">.</span><span class="n">autotune</span><span class="p">(</span><span class="n">_get_autotune_configs</span><span class="p">(</span><span class="n">inner_loop</span><span class="o">=</span><span class="s1">&#39;query&#39;</span><span class="p">),</span>
<span class="lineno">599</span>                 <span class="n">key</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;q_seq_len&quot;</span><span class="p">,</span> <span class="s2">&quot;kv_seq_len&quot;</span><span class="p">,</span> <span class="s2">&quot;d_head&quot;</span><span class="p">,</span> <span class="s2">&quot;n_groups&quot;</span><span class="p">,</span> <span class="s2">&quot;is_causal&quot;</span><span class="p">])</span>
<span class="lineno">600</span><span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="lineno">601</span><span class="k">def</span> <span class="nf">_attn_bwd_dkdv</span><span class="p">(</span><span class="n">t_q</span><span class="p">,</span> <span class="n">t_k</span><span class="p">,</span> <span class="n">t_v</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">,</span>
<span class="lineno">602</span>                   <span class="n">t_do</span><span class="p">,</span>
<span class="lineno">603</span>                   <span class="n">t_dk</span><span class="p">,</span> <span class="n">t_dv</span><span class="p">,</span>
<span class="lineno">604</span>                   <span class="n">t_lse</span><span class="p">,</span> <span class="n">t_pdp</span><span class="p">,</span>
<span class="lineno">605</span>                   <span class="n">q_seq_len</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">kv_seq_len</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">606</span>                   <span class="n">n_groups</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">d_head</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">607</span>                   <span class="n">is_causal</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">608</span>                   <span class="n">BLOCK_Q</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">609</span>                   <span class="n">BLOCK_K</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">610</span>                   <span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-71'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-71'>#</a>
            </div>
            <p>Compute <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.980548em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqcb" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.980548em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqce" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqck" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" 
style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> for <code  class="highlight"><span></span><span class="n">j</span></code>
 ... <code  class="highlight"><span></span><span class="n">j</span> <span class="o">+</span> <span class="n">BLOCK_K</span></code>
 by iterating over <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8777699999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">616</span>    <span class="n">j</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">*</span> <span class="n">BLOCK_K</span>
<span class="lineno">617</span>    <span class="n">z</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span></pre></div>
        </div>
    </div>
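    <div class='section'>
        <div class='docs'>
            <p>The work split described above can be sketched as follows: each program along grid axis 0 owns one block of keys and visits the query blocks in an inner loop. The launch grid size (a ceiling division of <code  class="highlight"><span></span><span class="n">kv_seq_len</span></code> by <code  class="highlight"><span></span><span class="n">BLOCK_K</span></code>) and the helper name <code  class="highlight"><span></span><span class="n">dkdv_work_items</span></code> are assumptions, and the sketch ignores causal masking, which may let a real kernel skip some query blocks.</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
def dkdv_work_items(q_seq_len, kv_seq_len, block_q, block_k):
    """Map each program along grid axis 0 to its key range and query blocks.

    Hypothetical helper; assumes one program per key block and no causal mask.
    """
    n_programs = -(-kv_seq_len // block_k)  # ceiling division (assumed grid size)
    work = []
    for pid in range(n_programs):
        j = pid * block_k                              # same as the kernel's `j`
        keys = (j, min(j + block_k, kv_seq_len))       # key rows owned by this program
        q_starts = list(range(0, q_seq_len, block_q))  # inner loop over query blocks
        work.append((keys, q_starts))
    return work


work = dkdv_work_items(q_seq_len=5, kv_seq_len=6, block_q=2, block_k=4)
```
</pre></div>
        </div>
    </div>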
    <div class='section' id='section-72'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-72'>#</a>
            </div>
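            <p>Create block pointers to the <code  class="highlight"><span></span><span class="n">BLOCK_K</span></code> rows of <code  class="highlight"><span></span><span class="n">t_k</span></code>, <code  class="highlight"><span></span><span class="n">t_v</span></code>, <code  class="highlight"><span></span><span class="n">t_dk</span></code> and <code  class="highlight"><span></span><span class="n">t_dv</span></code> that start at row <code  class="highlight"><span></span><span class="n">j</span></code></p>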

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">620</span>    <span class="n">p_k</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_k</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">kv_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">621</span>                            <span class="p">(</span><span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">622</span>                            <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="lineno">623</span>                            <span class="p">(</span><span class="n">j</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="lineno">624</span>                            <span class="p">(</span><span class="n">BLOCK_K</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">625</span>                            <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="lineno">626</span>    <span class="n">p_v</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_v</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">kv_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">627</span>                            <span class="p">(</span><span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">628</span>                            <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="lineno">629</span>                            <span class="p">(</span><span class="n">j</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="lineno">630</span>                            <span class="p">(</span><span class="n">BLOCK_K</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">631</span>                            <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="lineno">632</span>    <span class="n">p_dk</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_dk</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">kv_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">633</span>                             <span class="p">(</span><span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">634</span>                             <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="lineno">635</span>                             <span class="p">(</span><span class="n">j</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="lineno">636</span>                             <span class="p">(</span><span class="n">BLOCK_K</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">637</span>                             <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="lineno">638</span>    <span class="n">p_dv</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_dv</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">kv_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">639</span>                             <span class="p">(</span><span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">640</span>                             <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="lineno">641</span>                             <span class="p">(</span><span class="n">j</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="lineno">642</span>                             <span class="p">(</span><span class="n">BLOCK_K</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">643</span>                             <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-73'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-73'>#</a>
            </div>
            <p>Initialize <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.190108em;vertical-align:-0.345em;"></span><span class="mord coloredeq eqbh" style=""><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.845108em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqbv" style="margin-right:0.03588em">σ</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcs" style="">1</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcn" style="">d</span><span class="mord coloredeq eqcn" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqcq" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" 
style="margin-right:0.22222em">V</span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">646</span>    <span class="n">b_dk</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_K</span><span class="p">,</span> <span class="n">d_head</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">HI_PRES_TL</span><span class="p">)</span>
<span class="lineno">647</span>    <span class="n">b_dv</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_K</span><span class="p">,</span> <span class="n">d_head</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">HI_PRES_TL</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-74'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-74'>#</a>
            </div>
            <p>Load <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.1764999999999999em;vertical-align:-0.481108em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.695392em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mop mtight"><span class="mtight">l</span><span class="mtight">o</span><span class="mtight" style="margin-right:0.01389em;">g</span></span><span class="mspace mtight" style="margin-right:0.19516666666666668em;"></span><span class="mord mtight coloredeq eqct" style=""><span class="mord mtight" style="">2</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqbv" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">σ</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.481108em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord coloredeq eqcv" style=""><span class="mord mathnormal" style="margin-right:0.07153em">K</span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcx" style=""><span class="mord mathnormal" 
style="margin-right:0.22222em">V</span></span></span></span></span></span> outside the loop. </p>
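            <p>Folding the softmax scale and the factor 1 / log 2 into K once presumably lets the inner loop use a base-2 exponential, since exp(σs) = 2^(σs / log 2). The following is a minimal pure-Python check of that identity, not the kernel code; the vectors and <code>sm_scale</code> value are illustrative assumptions.</p>

```python
import math

# Illustrative values (assumptions, not taken from the kernel)
d_head = 4
sm_scale = 1.0 / math.sqrt(d_head)
q = [0.5, -1.0, 2.0, 0.25]
k = [1.5, 0.5, -0.75, 2.0]

# Fold the softmax scale and the change of base into K once, outside the loop
k_scaled = [x * sm_scale / math.log(2) for x in k]

# Reference: natural exponential of the scaled dot product
s = sum(a * b for a, b in zip(q, k))
reference = math.exp(sm_scale * s)

# Base-2 exponential of the dot product with the pre-scaled key
s2 = sum(a * b for a, b in zip(q, k_scaled))
base2 = 2.0 ** s2

print(reference, base2)
```

            <p>Both values agree up to floating-point rounding, so the scaling can be paid once per key block rather than once per query step.</p>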

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">650</span>    <span class="n">b_k</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">p_k</span><span class="p">,</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,),</span> <span class="n">padding_option</span><span class="o">=</span><span class="s2">&quot;zero&quot;</span><span class="p">)</span>
<span class="lineno">651</span>    <span class="n">b_v</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">p_v</span><span class="p">,</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,),</span> <span class="n">padding_option</span><span class="o">=</span><span class="s2">&quot;zero&quot;</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-75'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-75'>#</a>
            </div>
            <p>Iterate through the query heads of the GQA group that share this key/value head </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">654</span>    <span class="k">for</span> <span class="n">g</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_groups</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-76'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-76'>#</a>
            </div>
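            <p>Create block pointers into the query (transposed), output-gradient, log-sum-exp and <code  class="highlight"><span></span><span class="n">pdp</span></code> tensors for query head <code  class="highlight"><span></span><span class="n">g</span></code> </p>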
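            <p>The base offsets in these pointers are consistent with a contiguous <code>[z, n_groups, q_seq_len, d_head]</code> layout, and the transposed pointer only swaps shape and strides. The sketch below checks that arithmetic in plain Python under that layout assumption; the helper names <code>flat_index</code>, <code>q_base_offset</code> and <code>qT_offset</code> are hypothetical and the sizes are illustrative.</p>

```python
# Illustrative sizes (assumptions)
n_zs, n_groups, q_seq_len, d_head = 2, 3, 5, 4


def flat_index(z, g, i, c):
    """Row-major index of element [z, g, i, c] in a contiguous 4-D tensor."""
    return ((z * n_groups + g) * q_seq_len + i) * d_head + c


def q_base_offset(z, g):
    """Base offset of head (z, g), written the same way as in the pointer arithmetic."""
    return z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head


def qT_offset(c, i):
    """Offset inside one head for the transposed view:
    shape (d_head, q_seq_len) with strides (1, d_head)."""
    return c * 1 + i * d_head


z, g, i, c = 1, 2, 3, 1
offset = q_base_offset(z, g) + qT_offset(c, i)
print(offset, flat_index(z, g, i, c))
```

            <p>Both numbers are equal, which is why the transposed block can be read without physically transposing the query tensor.</p>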

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">656</span>        <span class="n">p_qT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_q</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">n_groups</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">*</span> <span class="n">d_head</span> <span class="o">+</span> <span class="n">g</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">657</span>                                 <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="n">q_seq_len</span><span class="p">),</span>
<span class="lineno">658</span>                                 <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">659</span>                                 <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="lineno">660</span>                                 <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="n">BLOCK_Q</span><span class="p">),</span>
<span class="lineno">661</span>                                 <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="lineno">662</span>
<span class="lineno">663</span>        <span class="n">p_do</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_do</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">n_groups</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">*</span> <span class="n">d_head</span> <span class="o">+</span> <span class="n">g</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">664</span>                                 <span class="p">(</span><span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">665</span>                                 <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="lineno">666</span>                                 <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="lineno">667</span>                                 <span class="p">(</span><span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">668</span>                                 <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="lineno">669</span>        <span class="n">p_lse</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_lse</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">n_groups</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">+</span> <span class="n">g</span> <span class="o">*</span> <span class="n">q_seq_len</span><span class="p">,</span>
<span class="lineno">670</span>                                  <span class="p">(</span><span class="n">q_seq_len</span><span class="p">,),</span>
<span class="lineno">671</span>                                  <span class="p">(</span><span class="mi">1</span><span class="p">,),</span>
<span class="lineno">672</span>                                  <span class="p">(</span><span class="mi">0</span><span class="p">,),</span>
<span class="lineno">673</span>                                  <span class="p">(</span><span class="n">BLOCK_Q</span><span class="p">,),</span>
<span class="lineno">674</span>                                  <span class="p">(</span><span class="mi">0</span><span class="p">,))</span>
<span class="lineno">675</span>        <span class="n">p_pdp</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_pdp</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">n_groups</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">+</span> <span class="n">g</span> <span class="o">*</span> <span class="n">q_seq_len</span><span class="p">,</span>
<span class="lineno">676</span>                                  <span class="p">(</span><span class="n">q_seq_len</span><span class="p">,),</span>
<span class="lineno">677</span>                                  <span class="p">(</span><span class="mi">1</span><span class="p">,),</span>
<span class="lineno">678</span>                                  <span class="p">(</span><span class="mi">0</span><span class="p">,),</span>
<span class="lineno">679</span>                                  <span class="p">(</span><span class="n">BLOCK_Q</span><span class="p">,),</span>
<span class="lineno">680</span>                                  <span class="p">(</span><span class="mi">0</span><span class="p">,))</span>
<span class="lineno">681</span>
<span class="lineno">682</span>        <span class="k">if</span> <span class="n">is_causal</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-77'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-77'>#</a>
            </div>
            <p>Inner loop over the diagonal block, where the causal mask is applied </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">684</span>            <span class="n">b_dk</span><span class="p">,</span> <span class="n">b_dv</span> <span class="o">=</span> <span class="n">_attn_bwd_dkdv_inner</span><span class="p">(</span>
<span class="lineno">685</span>                <span class="n">b_dk</span><span class="p">,</span> <span class="n">b_dv</span><span class="p">,</span>
<span class="lineno">686</span>                <span class="n">p_qT</span><span class="p">,</span> <span class="n">b_k</span><span class="p">,</span> <span class="n">b_v</span><span class="p">,</span> <span class="n">p_do</span><span class="p">,</span>
<span class="lineno">687</span>                <span class="n">p_lse</span><span class="p">,</span> <span class="n">p_pdp</span><span class="p">,</span>
<span class="lineno">688</span>                <span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">,</span>
<span class="lineno">689</span>                <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">690</span>                <span class="n">j</span><span class="o">=</span><span class="n">j</span><span class="p">,</span> <span class="n">i</span><span class="o">=</span><span class="n">j</span><span class="p">,</span>
<span class="lineno">691</span>                <span class="n">steps</span><span class="o">=</span><span class="n">BLOCK_K</span> <span class="o">//</span> <span class="n">BLOCK_Q</span><span class="p">,</span>
<span class="lineno">692</span>                <span class="n">MASK</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">693</span>                <span class="n">q_seq_len</span><span class="o">=</span><span class="n">q_seq_len</span><span class="p">,</span>
<span class="lineno">694</span>                <span class="n">kv_seq_len</span><span class="o">=</span><span class="n">kv_seq_len</span><span class="p">,</span>
<span class="lineno">695</span>            <span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-78'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-78'>#</a>
            </div>
            <p>Inner loop over the query blocks after the diagonal, which need no mask </p>
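            <p>Together with the diagonal call above, the causal branch visits the query blocks in two passes. The sketch below lists the query-block start offsets each pass would cover for one key block. It is plain Python that mirrors the <code>i</code> and <code>steps</code> arguments, and it assumes query and key positions are aligned; the function name <code>causal_query_blocks</code> is hypothetical.</p>

```python
def cdiv(a, b):
    """Ceiling division, mirroring tl.cdiv."""
    return (a + b - 1) // b


def causal_query_blocks(j, q_seq_len, BLOCK_Q, BLOCK_K):
    """Return (masked, unmasked) query-block start offsets for the key block starting at j."""
    assert BLOCK_K % BLOCK_Q == 0
    # Diagonal: BLOCK_K // BLOCK_Q steps starting at i = j, with the causal mask applied
    masked = [j + s * BLOCK_Q for s in range(BLOCK_K // BLOCK_Q)]
    # After the diagonal: every query sees every key of this block, so no mask is needed
    start = j + BLOCK_K
    steps = cdiv(q_seq_len - start, BLOCK_Q)
    unmasked = [start + s * BLOCK_Q for s in range(steps)]
    return masked, unmasked


masked, unmasked = causal_query_blocks(j=64, q_seq_len=256, BLOCK_Q=32, BLOCK_K=64)
print(masked)    # [64, 96]
print(unmasked)  # [128, 160, 192, 224]
```

            <p>Query blocks before <code>j</code> never appear, because under a causal mask those queries cannot attend to this key block and contribute nothing to its gradients.</p>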

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">698</span>            <span class="n">b_dk</span><span class="p">,</span> <span class="n">b_dv</span> <span class="o">=</span> <span class="n">_attn_bwd_dkdv_inner</span><span class="p">(</span>
<span class="lineno">699</span>                <span class="n">b_dk</span><span class="p">,</span> <span class="n">b_dv</span><span class="p">,</span>
<span class="lineno">700</span>                <span class="n">p_qT</span><span class="p">,</span> <span class="n">b_k</span><span class="p">,</span> <span class="n">b_v</span><span class="p">,</span> <span class="n">p_do</span><span class="p">,</span>
<span class="lineno">701</span>                <span class="n">p_lse</span><span class="p">,</span> <span class="n">p_pdp</span><span class="p">,</span>
<span class="lineno">702</span>                <span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">,</span>
<span class="lineno">703</span>                <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">704</span>                <span class="n">j</span><span class="o">=</span><span class="n">j</span><span class="p">,</span> <span class="n">i</span><span class="o">=</span><span class="n">j</span> <span class="o">+</span> <span class="n">BLOCK_K</span><span class="p">,</span>
<span class="lineno">705</span>                <span class="n">steps</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">cdiv</span><span class="p">((</span><span class="n">q_seq_len</span> <span class="o">-</span> <span class="p">(</span><span class="n">j</span> <span class="o">+</span> <span class="n">BLOCK_K</span><span class="p">)),</span> <span class="n">BLOCK_Q</span><span class="p">),</span>
<span class="lineno">706</span>                <span class="n">MASK</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="lineno">707</span>                <span class="n">q_seq_len</span><span class="o">=</span><span class="n">q_seq_len</span><span class="p">,</span>
<span class="lineno">708</span>                <span class="n">kv_seq_len</span><span class="o">=</span><span class="n">kv_seq_len</span>
<span class="lineno">709</span>            <span class="p">)</span>
<span class="lineno">710</span>        <span class="k">else</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-79'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-79'>#</a>
            </div>
            <p>Without causal masking, iterate through all query blocks with no mask </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">712</span>            <span class="n">b_dk</span><span class="p">,</span> <span class="n">b_dv</span> <span class="o">=</span> <span class="n">_attn_bwd_dkdv_inner</span><span class="p">(</span>
<span class="lineno">713</span>                <span class="n">b_dk</span><span class="p">,</span> <span class="n">b_dv</span><span class="p">,</span>
<span class="lineno">714</span>                <span class="n">p_qT</span><span class="p">,</span> <span class="n">b_k</span><span class="p">,</span> <span class="n">b_v</span><span class="p">,</span> <span class="n">p_do</span><span class="p">,</span>
<span class="lineno">715</span>                <span class="n">p_lse</span><span class="p">,</span> <span class="n">p_pdp</span><span class="p">,</span>
<span class="lineno">716</span>                <span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">,</span>
<span class="lineno">717</span>                <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">718</span>                <span class="n">j</span><span class="o">=</span><span class="n">j</span><span class="p">,</span> <span class="n">i</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">full</span><span class="p">([],</span> <span class="mi">0</span><span class="p">,</span> <span class="n">tl</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="lineno">719</span>                <span class="n">steps</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">q_seq_len</span><span class="p">,</span> <span class="n">BLOCK_Q</span><span class="p">),</span>
<span class="lineno">720</span>                <span class="n">MASK</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="lineno">721</span>                <span class="n">q_seq_len</span><span class="o">=</span><span class="n">q_seq_len</span><span class="p">,</span>
<span class="lineno">722</span>                <span class="n">kv_seq_len</span><span class="o">=</span><span class="n">kv_seq_len</span>
<span class="lineno">723</span>            <span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-80'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-80'>#</a>
            </div>
            <p>Save <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqcq" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">726</span>    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">p_dv</span><span class="p">,</span> <span class="n">b_dv</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">t_dv</span><span class="o">.</span><span class="n">type</span><span class="o">.</span><span class="n">element_ty</span><span class="p">),</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-81'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-81'>#</a>
            </div>
            <p><code  class="highlight"><span></span><span class="n">b_dk</span></code>
 so far holds <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.190108em;vertical-align:-0.345em;"></span><span class="mord coloredeq eqbh" style=""><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.845108em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqbv" style="margin-right:0.03588em">σ</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcs" style="">1</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcn" style="">d</span><span class="mord coloredeq eqcn" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span></span></span></span></span></span></span>, so scale it by <code  class="highlight"><span></span><span class="n">sm_scale</span></code> to get the key gradient </p>
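            <p>Deferring the scale works because multiplication distributes over the sum of per-block contributions. The toy check below uses made-up scalar contributions to one element of the key gradient; the numbers and variable names are illustrative assumptions, not kernel values.</p>

```python
import math

sm_scale = 0.125
# Hypothetical unscaled contributions from successive query blocks
contributions = [0.5, -1.25, 2.0, 0.75]

# Option 1: scale every contribution as it is accumulated
scaled_each_step = 0.0
for c in contributions:
    scaled_each_step += sm_scale * c

# Option 2: accumulate unscaled, then scale once at the end,
# in the same spirit as the single `b_dk *= sm_scale` after the loops
acc = 0.0
for c in contributions:
    acc += c
scaled_once = acc * sm_scale

print(scaled_each_step, scaled_once)
```

            <p>Scaling once after the loops saves one multiplication per accumulated block while giving the same result up to rounding.</p>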

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">729</span>    <span class="n">b_dk</span> <span class="o">*=</span> <span class="n">sm_scale</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-82'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-82'>#</a>
            </div>
            <p>Save <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqcn" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">732</span>    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">p_dk</span><span class="p">,</span> <span class="n">b_dk</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">t_dk</span><span class="o">.</span><span class="n">type</span><span class="o">.</span><span class="n">element_ty</span><span class="p">),</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-83'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-83'>#</a>
            </div>
            <h4>Inner loop to calculate <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.980548em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqcb" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span></span>, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.980548em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqce" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqck" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" 
style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></h4>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">735</span><span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="lineno">736</span><span class="k">def</span> <span class="nf">_attn_bwd_dkdv_inner</span><span class="p">(</span><span class="n">b_dk</span><span class="p">,</span> <span class="n">b_dv</span><span class="p">,</span>
<span class="lineno">737</span>                         <span class="n">p_qT</span><span class="p">,</span> <span class="n">b_k</span><span class="p">,</span> <span class="n">b_v</span><span class="p">,</span> <span class="n">p_do</span><span class="p">,</span>
<span class="lineno">738</span>                         <span class="n">p_lse</span><span class="p">,</span> <span class="n">p_pdp</span><span class="p">,</span>
<span class="lineno">739</span>                         <span class="n">BLOCK_Q</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">740</span>                         <span class="n">d_head</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">741</span>                         <span class="n">j</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">steps</span><span class="p">,</span>
<span class="lineno">742</span>                         <span class="n">MASK</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">743</span>                         <span class="n">q_seq_len</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">744</span>                         <span class="n">kv_seq_len</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-84'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-84'>#</a>
            </div>
            <p>To apply the mask, <code>BLOCK_K</code> must be a multiple of <code>BLOCK_Q</code> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">750</span>    <span class="n">tl</span><span class="o">.</span><span class="n">static_assert</span><span class="p">(</span><span class="n">BLOCK_K</span> <span class="o">%</span> <span class="n">BLOCK_Q</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-85'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-85'>#</a>
            </div>
            <p>Query offsets (<code>offs_i</code>) and key offsets (<code>offs_j</code>), used later to build the mask </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">753</span>    <span class="n">offs_i</span> <span class="o">=</span> <span class="n">i</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_Q</span><span class="p">)</span>
<span class="lineno">754</span>    <span class="n">offs_j</span> <span class="o">=</span> <span class="n">j</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-86'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-86'>#</a>
            </div>
            <p>Advance the block pointers to the starting query index <code>i</code> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">757</span>    <span class="n">p_qT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">advance</span><span class="p">(</span><span class="n">p_qT</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
<span class="lineno">758</span>    <span class="n">p_do</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">advance</span><span class="p">(</span><span class="n">p_do</span><span class="p">,</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="lineno">759</span>    <span class="n">p_lse</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">advance</span><span class="p">(</span><span class="n">p_lse</span><span class="p">,</span> <span class="p">(</span><span class="n">i</span><span class="p">,))</span>
<span class="lineno">760</span>    <span class="n">p_pdp</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">advance</span><span class="p">(</span><span class="n">p_pdp</span><span class="p">,</span> <span class="p">(</span><span class="n">i</span><span class="p">,))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-87'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-87'>#</a>
            </div>
            <p>Iterate over <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8777699999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcw" style=""><span class="mord mathnormal" style="">Q</span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">763</span>    <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">steps</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-88'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-88'>#</a>
            </div>
            <p>Load <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.109001em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqby" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqcj" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.914561em;"><span style="top:-3.1362300000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.13889em">T</span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">765</span>        <span class="n">b_qT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">p_qT</span><span class="p">,</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">padding_option</span><span class="o">=</span><span class="s2">&quot;zero&quot;</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-89'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-89'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord mathnormal" style="margin-right:0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqct" style=""><span class="mord mtight" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord coloredeq eqch" style=""><span class="mord" style=""><span class="mord mathnormal" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">768</span>        <span class="n">b_l</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">p_lse</span><span class="p">,</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,),</span> <span class="n">padding_option</span><span class="o">=</span><span class="s2">&quot;zero&quot;</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-90'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-90'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.200669em;vertical-align:-0.286108em;"></span><span class="mopen">(</span><span class="mord coloredeq eqbr" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span><span class="mclose">)</span><span class="mord"><span class="mord coloredeq eqbu" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.914561em;"><span style="top:-3.1362300000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.200669em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqbe" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbv" style="margin-right:0.03588em">σ</span></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mop coloredeq eqbr" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span><span 
class="mclose" style="">)</span><span class="mord" style=""><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span><span class="mord coloredeq eqby" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqcj" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.914561em;"><span style="top:-3.1362300000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" 
style="margin-right:0.13889em">T</span></span></span></span></span></span></span></span></span></span></span></span></span> </p>
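            <p>As a rough illustration of why the kernel multiplies the key block by the transposed query block, the sketch below checks on tiny hand-made matrices that computing the product of K and the transpose of Q gives the transpose of the usual score matrix directly. The helper names <code>matmul</code> and <code>transpose</code> are hypothetical and written in plain Python; the scale factor from the formula above is omitted on the assumption that it has already been folded into <code>b_k</code>.</p>

```python
def transpose(m):
    """Return the transpose of a list-of-lists matrix."""
    return [list(row) for row in zip(*m)]


def matmul(a, b):
    """Plain matrix product of two list-of-lists matrices."""
    b_t = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in b_t] for row in a]


# Toy block of 3 keys and 2 queries with head dimension 2 (made-up values)
k = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0]]
q = [[1.0, 0.0], [2.0, 1.0]]

# Usual layout: one row per query, one column per key
s = matmul(q, transpose(k))

# Transposed layout used in this loop: one row per key, one column per query
s_t = matmul(k, transpose(q))
```

            <p>Keeping keys on the rows means the key block stays fixed while the loop walks over query blocks, so no explicit transpose of the scores is needed.</p>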

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">771</span>        <span class="n">b_sT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">b_k</span><span class="p">,</span> <span class="n">b_qT</span><span class="p">,</span> <span class="n">out_dtype</span><span class="o">=</span><span class="n">HI_PRES_TL</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-91'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-91'>#</a>
            </div>
            <span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:6.803331em;vertical-align:-3.1516655em;"></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:3.6516655em;"><span style="top:-5.6983345000000005em;"><span class="pstrut" style="height:3.565em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqbt" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-2.997334499999999em;"><span class="pstrut" style="height:3.565em;"></span><span class="mord" style=""></span></span><span style="top:-1.0733345em;"><span class="pstrut" style="height:3.565em;"></span><span class="mord" style=""></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:3.1516655em;"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:3.6516655em;"><span style="top:-5.6983345000000005em;"><span class="pstrut" style="height:3.565em;"></span><span class="mord" style=""><span class="mord" style=""></span><span class="mspace" style="margin-right:0.2777777777777778em"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em"></span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.518331em;"><span style="top:-2.3139999999999996em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqch" style=""><span class="mord mathnormal" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.841331em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span 
class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbu" style=""><span class="mord mathnormal mtight" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.05764em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.8360000000000001em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span><span style="top:-2.997334499999999em;"><span class="pstrut" style="height:3.565em;"></span><span class="mord" style=""><span class="mord" style=""></span><span class="mspace" style="margin-right:0.2777777777777778em"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em"></span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.565em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style=""><span 
class="mord" style=""><span class="mord coloredeq eqct" style="">2</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7799659999999999em;"><span style="top:-2.9938580000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mop mtight coloredeq eqbp" style=""><span class="mop mtight" style=""><span class="mtight" style="">l</span><span class="mtight" style="">o</span><span class="mtight" style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.19444571428571428em;"><span style="top:-2.2341314285714287em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.26586857142857145em;"><span></span></span></span></span></span></span><span class="mspace mtight" style="margin-right:0.19516666666666668em"></span><span class="mord mtight coloredeq eqbp" style=""><span class="mord mtight coloredeq eqch" style=""><span class="mord mathnormal mtight" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span 
class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqct" style="">2</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8879999999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mopen mtight" style="">(</span><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span><span class="mord mathnormal mtight" style="">o</span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31731428571428577em;"><span style="top:-2.357em;margin-left:-0.03588em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcy" style="">e</span></span><span class="mclose mtight" style="">)</span><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbu" style=""><span 
class="mord mathnormal mtight" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.05764em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.686em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span><span style="top:-1.0733345em;"><span class="pstrut" style="height:3.565em;"></span><span class="mord" style=""><span class="mord" style=""></span><span class="mspace" style="margin-right:0.2777777777777778em"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em"></span><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqct" style="">2</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9379999999999998em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mopen mtight" style="">(</span><span class="mord mathnormal mtight" 
style="margin-right:0.01968em">l</span><span class="mord mathnormal mtight" style="">o</span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31731428571428577em;"><span style="top:-2.357em;margin-left:-0.03588em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcy" style="">e</span></span><span class="mclose mtight" style="">)</span><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbu" style=""><span class="mord mathnormal mtight" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.05764em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span><span class="mbin mtight" style="">−</span><span class="mord mtight" 
style=""><span class="mop mtight coloredeq eqbp" style=""><span class="mop mtight" style=""><span class="mtight" style="">l</span><span class="mtight" style="">o</span><span class="mtight" style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.19444571428571428em;"><span style="top:-2.2341314285714287em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.26586857142857145em;"><span></span></span></span></span></span></span><span class="mspace mtight" style="margin-right:0.19516666666666668em"></span><span class="mord mtight coloredeq eqbp" style=""><span class="mord mtight coloredeq eqch" style=""><span class="mord mathnormal mtight" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:3.1516655em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span><p> </p>
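            <p>The identity above can be checked numerically. This is a minimal pure-Python sketch, not the kernel itself: the function names and the sample scores are made up for illustration. In the kernel, the scores are assumed to arrive already multiplied by the base-2 logarithm of e, and <code>b_l</code> is assumed to hold the base-2 logarithm of the row normalizer, so a single <code>exp2</code> of their difference recovers the probabilities.</p>

```python
import math

LOG2_E = math.log2(math.e)  # log2(e), roughly 1.4427


def softmax_row_exp(scores):
    """Reference: P_ij = e^{S_ij} / L_i with L_i = sum_j e^{S_ij}."""
    l_i = sum(math.exp(s) for s in scores)
    return [math.exp(s) / l_i for s in scores]


def softmax_row_exp2(scores):
    """Same probabilities in base 2: P_ij = 2^{(log2 e) S_ij - log2 L_i}."""
    l_i = sum(math.exp(s) for s in scores)
    log2_l = math.log2(l_i)
    return [2.0 ** (LOG2_E * s - log2_l) for s in scores]


scores = [0.5, -1.0, 2.0, 0.0]  # one made-up row of attention scores
p_ref = softmax_row_exp(scores)
p_exp2 = softmax_row_exp2(scores)
```

            <p>Both versions agree up to floating-point rounding; the base-2 form is convenient because the subtraction and the exponential are the only per-element work left in the loop.</p>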

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">780</span>        <span class="n">b_pT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">math</span><span class="o">.</span><span class="n">exp2</span><span class="p">(</span><span class="n">b_sT</span> <span class="o">-</span> <span class="n">b_l</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:])</span></pre></div>
        </div>
    </div>
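    <div class='section'>
        <div class='docs'>
            <p>A minimal sketch of the base-2 trick used above, in plain Python rather than Triton. It assumes the scores were already multiplied by log2(e) and that the stored log-sum-exp is in base 2; both are inferred from the use of <code>exp2</code>, and the function names are hypothetical. </p>
        </div>
        <div class='code'>

```python
import math


def softmax_row(scores):
    """Reference softmax using natural exponentials."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_row_exp2(scores):
    """Same probabilities via base-2 exponentials and a base-2 log-sum-exp."""
    log2_e = math.log2(math.e)
    # Assumption: scores are pre-multiplied by log2(e) before exp2 is applied
    scaled = [s * log2_e for s in scores]
    m = max(scaled)
    # Base-2 log of the sum of 2^scaled, computed stably
    lse2 = m + math.log2(sum(2.0 ** (s - m) for s in scaled))
    # Recompute probabilities from the saved log-sum-exp, as in the backward pass
    return [2.0 ** (s - lse2) for s in scaled]


scores = [0.5, -1.0, 2.0, 0.0]
p_ref = softmax_row(scores)
p_exp2 = softmax_row_exp2(scores)
```

        </div>
    </div>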
    <div class='section' id='section-92'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-92'>#</a>
            </div>
            <p>Autoregressive masking </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">783</span>        <span class="k">if</span> <span class="n">MASK</span><span class="p">:</span>
<span class="lineno">784</span>            <span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">offs_i</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">&gt;=</span> <span class="n">offs_j</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">])</span>
<span class="lineno">785</span>            <span class="n">b_pT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">b_pT</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">)</span></pre></div>
        </div>
    </div>
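    <div class='section'>
        <div class='docs'>
            <p>The mask above is built in the transposed layout, with keys along rows and queries along columns. A small plain-Python illustration of the same comparison follows; the helper name is hypothetical and the offsets are example values. </p>
        </div>
        <div class='code'>

```python
def causal_mask_transposed(offs_j, offs_i):
    """mask[j][i] is True when query position i may attend to key position j."""
    return [[i >= j for i in offs_i] for j in offs_j]


# Three key positions (rows) against three query positions (columns)
mask = causal_mask_transposed(offs_j=[0, 1, 2], offs_i=[0, 1, 2])
# mask == [[True, True, True], [False, True, True], [False, False, True]]
```

        </div>
    </div>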
    <div class='section' id='section-93'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-93'>#</a>
            </div>
            <p>Mask out if the block is beyond the end of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8777699999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>
<p>Note: There is no need to mask based on <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.85396em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqda" style=""><span class="mord mathnormal" style="margin-right:0.05724em">j</span></span></span></span></span></span>, because contributions from positions outside the boundary are never stored in <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqcn" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span></span></span></span></span></span> or <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqcq" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span></span></span></span></span></span>. Masking by <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.65952em;vertical-align:0em;"></span><span class="mord coloredeq eqcz" style=""><span class="mord mathnormal" style="">i</span></span></span></span></span></span> may also be unnecessary, since out-of-bounds elements are loaded as zeros. </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">792</span>        <span class="n">i_mask</span> <span class="o">=</span> <span class="n">offs_i</span> <span class="o">&lt;</span> <span class="n">q_seq_len</span>
<span class="lineno">793</span>        <span class="n">b_pT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">i_mask</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:],</span> <span class="n">b_pT</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-94'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-94'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.980548em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqce" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqck" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.0497100000000001em;vertical-align:-0.29971000000000003em;"></span><span class="mop"><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em;">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16195399999999993em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span 
class="vlist" style="height:0.29971000000000003em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbt" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqci" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">796</span>        <span class="n">b_do</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">p_do</span><span class="p">,</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,),</span> <span class="n">padding_option</span><span class="o">=</span><span class="s2">&quot;zero&quot;</span><span class="p">)</span>
<span class="lineno">797</span>        <span class="n">b_dv</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">b_pT</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">b_do</span><span class="o">.</span><span class="n">dtype</span><span class="p">),</span> <span class="n">b_do</span><span class="p">,</span> <span class="n">out_dtype</span><span class="o">=</span><span class="n">HI_PRES_TL</span><span class="p">)</span></pre></div>
        </div>
    </div>
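    <div class='section'>
        <div class='docs'>
            <p>The accumulation above adds one query block's contribution to dV per loop iteration. The sketch below, in plain Python with hypothetical helper names and made-up values, checks that summing per-block products of the transposed P with dO gives the same result as a single full product. </p>
        </div>
        <div class='code'>

```python
def matmul(a, b):
    """Plain list-of-lists matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def dv_blockwise(p, do, block_q):
    """Accumulate dV = P^T dO over blocks of block_q query rows."""
    n_keys, d_head = len(p[0]), len(do[0])
    dv = [[0.0] * d_head for _ in range(n_keys)]
    for start in range(0, len(p), block_q):
        p_t = transpose(p[start:start + block_q])      # [n_keys, block_q]
        part = matmul(p_t, do[start:start + block_q])  # [n_keys, d_head]
        dv = [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(dv, part)]
    return dv


p = [[0.5, 0.5], [0.25, 0.75], [1.0, 0.0]]   # 3 queries, 2 keys
do = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]    # 3 queries, d_head = 2
dv_blocks = dv_blockwise(p, do, block_q=2)
dv_full = matmul(transpose(p), do)
```

        </div>
    </div>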
    <div class='section' id='section-95'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-95'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcf" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcu" style="margin-right:0.02778em">D</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">800</span>        <span class="n">b_pdp</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">p_pdp</span><span class="p">,</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,),</span> <span class="n">padding_option</span><span class="o">=</span><span class="s2">&quot;zero&quot;</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-96'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-96'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.980548em;vertical-align:-0.286108em;"></span><span class="mord mathnormal">d</span><span class="mord coloredeq eqbt" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.211779em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqck" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" 
style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mord"><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqci" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9256709999999999em;"><span style="top:-3.1473400000000002em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">802</span>        <span class="n">b_dpT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">b_v</span><span class="p">,</span> <span class="n">tl</span><span class="o">.</span><span class="n">trans</span><span class="p">(</span><span class="n">b_do</span><span class="p">),</span> <span class="n">out_dtype</span><span class="o">=</span><span class="n">HI_PRES_TL</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">HI_PRES_TL</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-97'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-97'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.20001em;vertical-align:-0.35001em;"></span><span class="mord coloredeq eqt" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqbu" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord" style=""><span class="mord coloredeq eqbt" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" 
style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mord" style=""><span class="delimsizing size1" style=""><span style="">(</span></span></span><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqbt" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style=""><span class="mord coloredeq eqcf" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcu" style="margin-right:0.02778em">D</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing 
reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord" style=""><span class="delimsizing size1" style=""><span style="">)</span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">804</span>        <span class="n">b_dsT</span> <span class="o">=</span> <span class="n">b_pT</span> <span class="o">*</span> <span class="p">(</span><span class="n">b_dpT</span> <span class="o">-</span> <span class="n">b_pdp</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:])</span></pre></div>
        </div>
    </div>
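    <div class='section'>
        <div class='docs'>
            <p>The expression above is the softmax backward rule for one row. As a rough check, the plain-Python sketch below compares it with a finite-difference gradient. It assumes D is the sum over keys of P times dP for that row, which is what the name <code>b_pdp</code> suggests; the function names and test values are hypothetical. </p>
        </div>
        <div class='code'>

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_backward(p, dp):
    """dS_j = P_j * (dP_j - D), where D = sum_k P_k * dP_k (assumed definition)."""
    d = sum(pk * dpk for pk, dpk in zip(p, dp))
    return [pj * (dpj - d) for pj, dpj in zip(p, dp)]


def numerical_grad(s, dp, eps=1e-6):
    """Central finite difference of sum_k dP_k * softmax(s)_k with respect to s."""
    grads = []
    for j in range(len(s)):
        hi = list(s)
        hi[j] += eps
        lo = list(s)
        lo[j] -= eps
        f_hi = sum(g * p for g, p in zip(dp, softmax(hi)))
        f_lo = sum(g * p for g, p in zip(dp, softmax(lo)))
        grads.append((f_hi - f_lo) / (2 * eps))
    return grads


s = [0.3, -1.2, 2.0]       # one row of attention scores
dp = [1.0, -0.5, 0.25]     # upstream gradient with respect to P for that row
ds_formula = softmax_backward(softmax(s), dp)
ds_numeric = numerical_grad(s, dp)
```

        </div>
    </div>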
    <div class='section' id='section-98'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-98'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.375916em;vertical-align:-0.530808em;"></span><span class="mord"><span class="mord coloredeq eqbh" style=""><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.845108em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqbv" style="margin-right:0.03588em">σ</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqcs" style="">1</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcn" style="">d</span><span class="mord coloredeq eqcn" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.06696400000000002em;"><span style="top:-2.3053000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqda" style=""><span 
class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.530808em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.0497100000000001em;vertical-align:-0.29971000000000003em;"></span><span class="mop"><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em;">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16195399999999993em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqcz" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.29971000000000003em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal">d</span><span class="mord coloredeq eqbu" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" 
style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqcj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">806</span>        <span class="n">b_dk</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">b_dsT</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">b_qT</span><span class="o">.</span><span class="n">dtype</span><span class="p">),</span> <span class="n">tl</span><span class="o">.</span><span class="n">trans</span><span class="p">(</span><span class="n">b_qT</span><span class="p">),</span> <span class="n">out_dtype</span><span class="o">=</span><span class="n">HI_PRES_TL</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-99'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-99'>#</a>
            </div>
            <p>Increment pointers. </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">809</span>        <span class="n">offs_i</span> <span class="o">+=</span> <span class="n">BLOCK_Q</span>
<span class="lineno">810</span>        <span class="n">p_lse</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">advance</span><span class="p">(</span><span class="n">p_lse</span><span class="p">,</span> <span class="p">(</span><span class="n">BLOCK_Q</span><span class="p">,))</span>
<span class="lineno">811</span>        <span class="n">p_pdp</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">advance</span><span class="p">(</span><span class="n">p_pdp</span><span class="p">,</span> <span class="p">(</span><span class="n">BLOCK_Q</span><span class="p">,))</span>
<span class="lineno">812</span>        <span class="n">p_qT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">advance</span><span class="p">(</span><span class="n">p_qT</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_Q</span><span class="p">))</span>
<span class="lineno">813</span>        <span class="n">p_do</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">advance</span><span class="p">(</span><span class="n">p_do</span><span class="p">,</span> <span class="p">(</span><span class="n">BLOCK_Q</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-100'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-100'>#</a>
            </div>
            <p>Return accumulated <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqcn" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqcq" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">816</span>    <span class="k">return</span> <span class="n">b_dk</span><span class="p">,</span> <span class="n">b_dv</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-101'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-101'>#</a>
            </div>
            <h4>Triton kernel to compute <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcd" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqcj" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></h4>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">819</span><span class="nd">@triton</span><span class="o">.</span><span class="n">autotune</span><span class="p">(</span><span class="n">_get_autotune_configs</span><span class="p">(</span><span class="n">inner_loop</span><span class="o">=</span><span class="s1">&#39;key&#39;</span><span class="p">),</span>
<span class="lineno">820</span>                 <span class="n">key</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;q_seq_len&quot;</span><span class="p">,</span> <span class="s2">&quot;kv_seq_len&quot;</span><span class="p">,</span> <span class="s2">&quot;d_head&quot;</span><span class="p">,</span> <span class="s2">&quot;n_groups&quot;</span><span class="p">,</span> <span class="s2">&quot;is_causal&quot;</span><span class="p">])</span>
<span class="lineno">821</span><span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="lineno">822</span><span class="k">def</span> <span class="nf">_attn_bwd_dq</span><span class="p">(</span><span class="n">t_q</span><span class="p">,</span> <span class="n">t_k</span><span class="p">,</span> <span class="n">t_v</span><span class="p">,</span> <span class="n">t_do</span><span class="p">,</span>
<span class="lineno">823</span>                 <span class="n">t_dq</span><span class="p">,</span>
<span class="lineno">824</span>                 <span class="n">t_lse</span><span class="p">,</span> <span class="n">t_pdp</span><span class="p">,</span>
<span class="lineno">825</span>                 <span class="n">q_seq_len</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">kv_seq_len</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">826</span>                 <span class="n">n_groups</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">d_head</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">827</span>                 <span class="n">is_causal</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">828</span>                 <span class="n">BLOCK_Q</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">829</span>                 <span class="n">BLOCK_K</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">830</span>                 <span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-102'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-102'>#</a>
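            </div>
            <p>Work out which tile this program handles: <code>i</code> is the first query row of the block, <code>z</code> selects the key/value slice, and <code>g</code> selects the query group that shares that slice.</p>
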
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">835</span>    <span class="n">i</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">*</span> <span class="n">BLOCK_Q</span>
<span class="lineno">836</span>    <span class="n">z</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">n_groups</span>
<span class="lineno">837</span>    <span class="n">g</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="n">n_groups</span>  <span class="c1"># TODO</span></pre></div>
        </div>
    </div>
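    <div class='section'>
        <div class='docs'>
            <p>As a rough illustration of the index arithmetic above, the following plain-Python sketch (not part of the kernel) mimics how one launch-grid coordinate maps to a query block start <code>i</code>, a key/value slice <code>z</code>, and a group <code>g</code>. The helper name <code>decompose_program_ids</code> and the sample sizes are assumptions made for this example; <code>pid0</code> and <code>pid1</code> stand in for <code>tl.program_id(0)</code> and <code>tl.program_id(1)</code>.</p>

        </div>
        <div class='code'>

```python
def decompose_program_ids(pid0, pid1, block_q, n_groups):
    """Mirror the kernel's index arithmetic in plain Python (illustrative only)."""
    i = pid0 * block_q      # first query row handled by this program
    z = pid1 // n_groups    # which key/value slice
    g = pid1 % n_groups     # which query group within that slice
    return i, z, g


# Hypothetical sizes: 3 query groups per key/value slice, 64-row query blocks.
example = decompose_program_ids(pid0=2, pid1=7, block_q=64, n_groups=3)
print(example)
```

        </div>
    </div>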
    <div class='section' id='section-103'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-103'>#</a>
            </div>
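            <p>Create block pointers. <code>p_kT</code> and <code>p_vT</code> read K and V through transposed <code>(d_head, kv_seq_len)</code> views, and their base offsets depend only on <code>z</code>, so every query group <code>g</code> reads the same keys and values.</p>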

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">840</span>    <span class="n">p_q</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_q</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">n_groups</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">*</span> <span class="n">d_head</span> <span class="o">+</span> <span class="n">g</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">841</span>                            <span class="p">(</span><span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">842</span>                            <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="lineno">843</span>                            <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="lineno">844</span>                            <span class="p">(</span><span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">845</span>                            <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="lineno">846</span>    <span class="n">p_dq</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_dq</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">n_groups</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">*</span> <span class="n">d_head</span> <span class="o">+</span> <span class="n">g</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">847</span>                             <span class="p">(</span><span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">848</span>                             <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="lineno">849</span>                             <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="lineno">850</span>                             <span class="p">(</span><span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">851</span>                             <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="lineno">852</span>    <span class="n">p_do</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_do</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">n_groups</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">*</span> <span class="n">d_head</span> <span class="o">+</span> <span class="n">g</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">853</span>                             <span class="p">(</span><span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">854</span>                             <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="lineno">855</span>                             <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="lineno">856</span>                             <span class="p">(</span><span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">857</span>                             <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="lineno">858</span>    <span class="n">p_kT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_k</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">kv_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">859</span>                             <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="n">kv_seq_len</span><span class="p">),</span>
<span class="lineno">860</span>                             <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">861</span>                             <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="lineno">862</span>                             <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">),</span>
<span class="lineno">863</span>                             <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="lineno">864</span>    <span class="n">p_vT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_v</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">kv_seq_len</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">,</span>
<span class="lineno">865</span>                             <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="n">kv_seq_len</span><span class="p">),</span>
<span class="lineno">866</span>                             <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">867</span>                             <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="lineno">868</span>                             <span class="p">(</span><span class="n">d_head</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">),</span>
<span class="lineno">869</span>                             <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="lineno">870</span>    <span class="n">p_lse</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_lse</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">n_groups</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">+</span> <span class="n">g</span> <span class="o">*</span> <span class="n">q_seq_len</span><span class="p">,</span>
<span class="lineno">871</span>                              <span class="p">(</span><span class="n">q_seq_len</span><span class="p">,),</span>
<span class="lineno">872</span>                              <span class="p">(</span><span class="mi">1</span><span class="p">,),</span>
<span class="lineno">873</span>                              <span class="p">(</span><span class="n">i</span><span class="p">,),</span>
<span class="lineno">874</span>                              <span class="p">(</span><span class="n">BLOCK_Q</span><span class="p">,),</span>
<span class="lineno">875</span>                              <span class="p">(</span><span class="mi">0</span><span class="p">,))</span>
<span class="lineno">876</span>    <span class="n">p_pdp</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">make_block_ptr</span><span class="p">(</span><span class="n">t_pdp</span> <span class="o">+</span> <span class="n">z</span> <span class="o">*</span> <span class="n">n_groups</span> <span class="o">*</span> <span class="n">q_seq_len</span> <span class="o">+</span> <span class="n">g</span> <span class="o">*</span> <span class="n">q_seq_len</span><span class="p">,</span>
<span class="lineno">877</span>                              <span class="p">(</span><span class="n">q_seq_len</span><span class="p">,),</span>
<span class="lineno">878</span>                              <span class="p">(</span><span class="mi">1</span><span class="p">,),</span>
<span class="lineno">879</span>                              <span class="p">(</span><span class="n">i</span><span class="p">,),</span>
<span class="lineno">880</span>                              <span class="p">(</span><span class="n">BLOCK_Q</span><span class="p">,),</span>
<span class="lineno">881</span>                              <span class="p">(</span><span class="mi">0</span><span class="p">,))</span></pre></div>
        </div>
    </div>
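    <div class='section'>
        <div class='docs'>
            <p>The transposed views used for K and V may be easier to see in plain Python. This sketch is only a stand-in for <code>tl.make_block_ptr</code>: the hypothetical helper <code>flat_index</code> applies a base offset plus strides, which is enough to show that strides <code>(1, d_head)</code> over a row-major <code>(kv_seq_len, d_head)</code> buffer read it as its transpose. The buffer contents and sizes are made up for the example.</p>

        </div>
        <div class='code'>

```python
def flat_index(base, strides, coords):
    """Flat offset of an element: base + sum(stride * coordinate)."""
    return base + sum(s * c for s, c in zip(strides, coords))


# Hypothetical tiny K buffer: 2 slices, kv_seq_len = 3, d_head = 2, row-major.
kv_seq_len, d_head = 3, 2
k_flat = list(range(2 * kv_seq_len * d_head))

z = 1                              # pick the second slice
base = z * kv_seq_len * d_head     # same offset expression as the kernel uses for t_k

# Row-major view (kv_seq_len, d_head) has strides (d_head, 1).
k = [[k_flat[flat_index(base, (d_head, 1), (j, d))] for d in range(d_head)]
     for j in range(kv_seq_len)]

# Transposed view (d_head, kv_seq_len) has strides (1, d_head), as for p_kT.
k_t = [[k_flat[flat_index(base, (1, d_head), (d, j))] for j in range(kv_seq_len)]
       for d in range(d_head)]

print(k)    # rows of K for slice z
print(k_t)  # the same data read as K transposed
```

        </div>
    </div>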
    <div class='section' id='section-104'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-104'>#</a>
            </div>
            <p>Load <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8777699999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqci" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></span>, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqcf" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcu" style="margin-right:0.02778em">D</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>, and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.93858em;vertical-align:-0.24414em;"></span><span class="mord coloredeq eqbp" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqch" style=""><span class="mord mathnormal" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> outside the loop </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">884</span>    <span class="n">b_q</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">p_q</span><span class="p">,</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,),</span> <span class="n">padding_option</span><span class="o">=</span><span class="s2">&quot;zero&quot;</span><span class="p">)</span>
<span class="lineno">885</span>    <span class="n">b_do</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">p_do</span><span class="p">,</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,),</span> <span class="n">padding_option</span><span class="o">=</span><span class="s2">&quot;zero&quot;</span><span class="p">)</span>
<span class="lineno">886</span>    <span class="n">b_pdp</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">p_pdp</span><span class="p">,</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,),</span> <span class="n">padding_option</span><span class="o">=</span><span class="s2">&quot;zero&quot;</span><span class="p">)</span>
<span class="lineno">887</span>    <span class="n">b_lse</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">p_lse</span><span class="p">,</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,),</span> <span class="n">padding_option</span><span class="o">=</span><span class="s2">&quot;zero&quot;</span><span class="p">)</span></pre></div>
        </div>
    </div>
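    <div class='section'>
        <div class='docs'>
            <p>What <code>boundary_check=(0,)</code> with <code>padding_option="zero"</code> is meant to achieve can be sketched without Triton. The hypothetical <code>load_block_zero_padded</code> below is not a Triton API; it only imitates the effect of reading a <code>BLOCK_Q</code>-row tile that runs past the end of the sequence, where out-of-range rows come back as zeros. The sizes are assumptions for illustration.</p>

        </div>
        <div class='code'>

```python
def load_block_zero_padded(rows, start, block_rows):
    """Return block_rows rows starting at start, padding past the end with zeros."""
    width = len(rows[0])
    block = []
    for r in range(start, start + block_rows):
        if r in range(len(rows)):      # in bounds: copy the row
            block.append(list(rows[r]))
        else:                          # out of bounds: pad with zeros
            block.append([0.0] * width)
    return block


# Hypothetical sizes: q_seq_len = 5, d_head = 2, BLOCK_Q = 4.
q = [[float(10 * r + c) for c in range(2)] for r in range(5)]
last_block = load_block_zero_padded(q, start=4, block_rows=4)
print(last_block)
```

        </div>
    </div>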
    <div class='section' id='section-105'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-105'>#</a>
            </div>
            <p>Initialize <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqbl" style=""><span class="mopen" style="">(</span><span class="mord" style=""><span class="mop coloredeq eqbr" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span><span class="mclose" style="">)</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqco" style="">d</span><span class="mord coloredeq eqco" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">890</span>    <span class="n">b_dq</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">d_head</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">HI_PRES_TL</span><span class="p">)</span></pre></div>
        </div>
    </div>
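    <div class='section'>
        <div class='docs'>
            <p>The base-2 logarithms in the notation here go back to the identity e^x = 2^(x log2 e), which lets the probabilities be rebuilt with a base-2 exponential. A minimal sketch, assuming the scores are pre-multiplied by log2 e and the stored per-row statistic is the base-2 log-sum-exp (one reading of the loaded log2 L_i; the variable names are illustrative and not taken from the kernel):</p>

        </div>
        <div class='code'>

```python
import math

LOG2_E = math.log2(math.e)   # about 1.4427

scores = [0.5, -1.0, 2.0]    # hypothetical attention scores for one query row

# Reference softmax with the natural exponential.
denom = sum(math.exp(s) for s in scores)
p_ref = [math.exp(s) / denom for s in scores]

# Base-2 route: scale the scores by log2(e), keep log2 of the row sum.
scores2 = [s * LOG2_E for s in scores]
lse2 = math.log2(sum(2.0 ** s for s in scores2))
p_base2 = [2.0 ** (s - lse2) for s in scores2]

# Largest absolute difference between the two routes.
print(max(abs(a - b) for a, b in zip(p_ref, p_base2)))
```

        </div>
    </div>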
    <div class='section' id='section-106'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-106'>#</a>
            </div>
            <p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:2.463782em;vertical-align:-1.413777em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mop op-limits" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.050005em;"><span style="top:-1.8723309999999997em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span><span style="top:-3.0500049999999996em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op" style="">∑</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.413777em;"><span></span></span></span></span></span><span class="mspace" 
style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqbu" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03148em">k</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.03148em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mop op-limits" style=""><span class="vlist-t vlist-t2"><span 
class="vlist-r"><span class="vlist" style="height:1.050005em;"><span style="top:-1.8723309999999997em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span><span style="top:-3.0500049999999996em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op" style="">∑</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.413777em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqbt" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mord" style=""><span class="delimsizing size1" style=""><span style="">(</span></span></span><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqbt" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span 
class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style=""><span class="mord coloredeq eqcf" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcu" style="margin-right:0.02778em">D</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord" style=""><span class="delimsizing size1" style=""><span style="">)</span></span></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03148em">k</span><span class="msupsub"><span class="vlist-t vlist-t2"><span 
class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.03148em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">894</span>    <span class="k">if</span> <span class="n">is_causal</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-107'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-107'>#</a>
            </div>
            <p>Compute <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqco" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span></span></span></span></span></span> for masked (diagonal) blocks. </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">896</span>        <span class="n">b_dq</span> <span class="o">=</span> <span class="n">_attn_bwd_dq_inner</span><span class="p">(</span><span class="n">b_dq</span><span class="p">,</span> <span class="n">b_q</span><span class="p">,</span> <span class="n">p_kT</span><span class="p">,</span> <span class="n">p_vT</span><span class="p">,</span>
<span class="lineno">897</span>                                  <span class="n">b_do</span><span class="p">,</span> <span class="n">b_lse</span><span class="p">,</span> <span class="n">b_pdp</span><span class="p">,</span>
<span class="lineno">898</span>                                  <span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">,</span>
<span class="lineno">899</span>                                  <span class="n">i</span><span class="o">=</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="o">=</span><span class="n">i</span><span class="p">,</span>
<span class="lineno">900</span>                                  <span class="n">steps</span><span class="o">=</span><span class="n">BLOCK_Q</span> <span class="o">//</span> <span class="n">BLOCK_K</span><span class="p">,</span>
<span class="lineno">901</span>                                  <span class="n">MASK</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">902</span>                                  <span class="n">q_seq_len</span><span class="o">=</span><span class="n">q_seq_len</span><span class="p">,</span>
<span class="lineno">903</span>                                  <span class="n">kv_seq_len</span><span class="o">=</span><span class="n">kv_seq_len</span>
<span class="lineno">904</span>                                  <span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-108'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-108'>#</a>
            </div>
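            <p>Compute <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqco" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span></span></span></span></span></span> for the remaining blocks, i.e. the key blocks that lie entirely before the query block and therefore need no mask. </p>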

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">907</span>        <span class="n">b_dq</span> <span class="o">=</span> <span class="n">_attn_bwd_dq_inner</span><span class="p">(</span><span class="n">b_dq</span><span class="p">,</span> <span class="n">b_q</span><span class="p">,</span> <span class="n">p_kT</span><span class="p">,</span> <span class="n">p_vT</span><span class="p">,</span>
<span class="lineno">908</span>                                  <span class="n">b_do</span><span class="p">,</span> <span class="n">b_lse</span><span class="p">,</span> <span class="n">b_pdp</span><span class="p">,</span>
<span class="lineno">909</span>                                  <span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">,</span>
<span class="lineno">910</span>                                  <span class="n">i</span><span class="o">=</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">full</span><span class="p">([],</span> <span class="mi">0</span><span class="p">,</span> <span class="n">tl</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>  <span class="c1"># type: ignore</span>
<span class="lineno">911</span>                                  <span class="n">steps</span><span class="o">=</span><span class="n">i</span> <span class="o">//</span> <span class="n">BLOCK_K</span><span class="p">,</span>
<span class="lineno">912</span>                                  <span class="n">MASK</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="lineno">913</span>                                  <span class="n">q_seq_len</span><span class="o">=</span><span class="n">q_seq_len</span><span class="p">,</span>
<span class="lineno">914</span>                                  <span class="n">kv_seq_len</span><span class="o">=</span><span class="n">kv_seq_len</span>
<span class="lineno">915</span>                                  <span class="p">)</span>
<span class="lineno">916</span>    <span class="k">else</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-109'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-109'>#</a>
            </div>
            <p>Iterate through all <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcv" style=""><span class="mord mathnormal" style="margin-right:0.07153em">K</span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">918</span>        <span class="n">b_dq</span> <span class="o">=</span> <span class="n">_attn_bwd_dq_inner</span><span class="p">(</span><span class="n">b_dq</span><span class="p">,</span> <span class="n">b_q</span><span class="p">,</span> <span class="n">p_kT</span><span class="p">,</span> <span class="n">p_vT</span><span class="p">,</span>
<span class="lineno">919</span>                                  <span class="n">b_do</span><span class="p">,</span> <span class="n">b_lse</span><span class="p">,</span> <span class="n">b_pdp</span><span class="p">,</span>
<span class="lineno">920</span>                                  <span class="n">BLOCK_Q</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">,</span>
<span class="lineno">921</span>                                  <span class="n">i</span><span class="o">=</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">full</span><span class="p">([],</span> <span class="mi">0</span><span class="p">,</span> <span class="n">tl</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>  <span class="c1"># type: ignore</span>
<span class="lineno">922</span>                                  <span class="n">steps</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">),</span>
<span class="lineno">923</span>                                  <span class="n">MASK</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="lineno">924</span>                                  <span class="n">q_seq_len</span><span class="o">=</span><span class="n">q_seq_len</span><span class="p">,</span>
<span class="lineno">925</span>                                  <span class="n">kv_seq_len</span><span class="o">=</span><span class="n">kv_seq_len</span>
<span class="lineno">926</span>                                  <span class="p">)</span></pre></div>
        </div>
    </div>
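    <div class='section'>
        <div class='docs'>
            <p>The three calls above differ only in the starting column <code>j</code>, the number of <code>steps</code>, and the <code>MASK</code> flag. The following plain-Python sketch lists which key blocks each pass would visit. It is an illustration, not part of the kernel: <code>dq_key_block_schedule</code> is a hypothetical helper, and it assumes that <code>i</code> is the first query row of the block and that each inner step advances the key offset by <code>BLOCK_K</code>. </p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
def dq_key_block_schedule(i, block_q, block_k, kv_seq_len, is_causal):
    """Return (start_column, masked) pairs in the order the passes visit them.

    Hypothetical helper for illustration only. It mirrors the j, steps and
    MASK arguments of the calls above and assumes each step advances the
    key offset by block_k.
    """
    assert block_q % block_k == 0, "block_q must be divisible by block_k"
    schedule = []
    if is_causal:
        # Diagonal pass: starts at column i, spans block_q columns, masked.
        for step in range(block_q // block_k):
            schedule.append((i + step * block_k, True))
        # Off-diagonal pass: columns 0 up to (not including) i, no mask.
        for step in range(i // block_k):
            schedule.append((step * block_k, False))
    else:
        # Non-causal: ceil(kv_seq_len / block_k) blocks starting at column 0.
        steps = (kv_seq_len + block_k - 1) // block_k
        for step in range(steps):
            schedule.append((step * block_k, False))
    return schedule


causal = dq_key_block_schedule(i=128, block_q=64, block_k=32,
                               kv_seq_len=256, is_causal=True)
full = dq_key_block_schedule(i=128, block_q=64, block_k=32,
                             kv_seq_len=100, is_causal=False)
print(causal)
print(full)
```
</pre></div>
        </div>
    </div>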
    <div class='section' id='section-110'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-110'>#</a>
            </div>
            <p><code  class="highlight"><span></span><span class="n">b_dq</span></code>
 stores <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqbl" style=""><span class="mopen" style="">(</span><span class="mord" style=""><span class="mop coloredeq eqbr" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span><span class="mclose" style="">)</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqco" style="">d</span><span class="mord coloredeq eqco" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span></span></span></span></span></span></span> so multiply by <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.93858em;vertical-align:-0.24414em;"></span><span class="mop"><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.057252em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" 
style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqcy" style=""><span class="mord mathnormal mtight" style="">e</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqct" style=""><span class="mord" style="">2</span></span></span></span></span></span> to get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqco" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">929</span>    <span class="n">b_dq</span> <span class="o">*=</span> <span class="mf">0.6931471824645996</span></pre></div>
        </div>
    </div>
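    <div class='section'>
        <div class='docs'>
            <p>A small pure-Python check of why the final multiplication by ln 2 recovers the gradient. This is a sketch under an assumption not shown in this section: that the keys were pre-multiplied by the softmax scale times log2(e), so the scores are in base 2 and the accumulator picks up one extra factor of log2(e). The names <code>dq_natural</code>, <code>dq_base2</code> and the toy inputs are hypothetical. </p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
import math

LOG2_E = math.log2(math.e)   # log2(e), about 1.4427
LN_2 = math.log(2.0)         # ln(2), about 0.6931; LOG2_E * LN_2 equals 1


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def ds_times_k(p, k_rows, v, do):
    """Return dS @ K with dS = P * (dP - D) and D_i = sum_j P_ij * dP_ij."""
    dp = matmul(do, transpose(v))
    ds = []
    for p_row, dp_row in zip(p, dp):
        d = sum(x * y for x, y in zip(p_row, dp_row))
        ds.append([x * (y - d) for x, y in zip(p_row, dp_row)])
    return matmul(ds, k_rows)


def dq_natural(q, k, v, do, scale):
    """Reference dQ using a natural-base softmax."""
    s = [[scale * x for x in row] for row in matmul(q, transpose(k))]
    p = []
    for row in s:
        lse = math.log(sum(math.exp(x) for x in row))
        p.append([math.exp(x - lse) for x in row])
    return [[scale * x for x in row] for row in ds_times_k(p, k, v, do)]


def dq_base2(q, k, v, do, scale):
    """dQ using base-2 scores; K is assumed pre-scaled by scale * log2(e)."""
    k2 = [[scale * LOG2_E * x for x in row] for row in k]
    s2 = matmul(q, transpose(k2))             # (log2 e) * S
    p = []
    for row in s2:
        lse2 = math.log2(sum(2.0 ** x for x in row))
        p.append([2.0 ** (x - lse2) for x in row])
    acc = ds_times_k(p, k2, v, do)            # holds (log2 e) * dQ
    return [[LN_2 * x for x in row] for row in acc]


# Toy inputs: 2 queries, 3 keys, head dimension 2.
q = [[0.1, -0.2], [0.3, 0.4]]
k = [[0.5, 0.1], [-0.3, 0.2], [0.0, 0.7]]
v = [[1.0, 0.0], [0.5, -1.0], [0.2, 0.3]]
do = [[0.1, 0.2], [-0.4, 0.3]]
scale = 1.0 / math.sqrt(2.0)

ref = dq_natural(q, k, v, do, scale)
out = dq_base2(q, k, v, do, scale)
print(ref)
print(out)
```
</pre></div>
        </div>
    </div>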
    <div class='section' id='section-111'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-111'>#</a>
            </div>
            <p>Save <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqco" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">932</span>    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">p_dq</span><span class="p">,</span> <span class="n">b_dq</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">t_dq</span><span class="o">.</span><span class="n">type</span><span class="o">.</span><span class="n">element_ty</span><span class="p">),</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-112'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-112'>#</a>
            </div>
            <h4>Inner loop to calculate <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcd" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqcj" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></h4>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">935</span><span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="lineno">936</span><span class="k">def</span> <span class="nf">_attn_bwd_dq_inner</span><span class="p">(</span><span class="n">b_dq</span><span class="p">,</span> <span class="n">b_q</span><span class="p">,</span> <span class="n">p_kT</span><span class="p">,</span> <span class="n">p_vT</span><span class="p">,</span>
<span class="lineno">937</span>                       <span class="n">b_do</span><span class="p">,</span> <span class="n">b_lse</span><span class="p">,</span> <span class="n">b_pdp</span><span class="p">,</span>
<span class="lineno">938</span>                       <span class="n">BLOCK_Q</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">939</span>                       <span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">steps</span><span class="p">,</span>
<span class="lineno">940</span>                       <span class="n">MASK</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">941</span>                       <span class="n">q_seq_len</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="lineno">942</span>                       <span class="n">kv_seq_len</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-113'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-113'>#</a>
            </div>
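            <p>Offsets of the query rows (starting at <code  class="highlight"><span></span><span class="n">i</span></code>) and key columns (starting at <code  class="highlight"><span></span><span class="n">j</span></code>) covered by the current blocks </p>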

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">948</span>    <span class="n">offs_i</span> <span class="o">=</span> <span class="n">i</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_Q</span><span class="p">)</span>
<span class="lineno">949</span>    <span class="n">offs_j</span> <span class="o">=</span> <span class="n">j</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-114'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-114'>#</a>
            </div>
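            <p>Advance the <code  class="highlight"><span></span><span class="n">p_kT</span></code> and <code  class="highlight"><span></span><span class="n">p_vT</span></code> block pointers to the starting column <code  class="highlight"><span></span><span class="n">j</span></code>, and check at compile time that <code  class="highlight"><span></span><span class="n">BLOCK_Q</span></code> is a multiple of <code  class="highlight"><span></span><span class="n">BLOCK_K</span></code> </p>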

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">952</span>    <span class="n">p_kT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">advance</span><span class="p">(</span><span class="n">p_kT</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">j</span><span class="p">))</span>
<span class="lineno">953</span>    <span class="n">p_vT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">advance</span><span class="p">(</span><span class="n">p_vT</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">j</span><span class="p">))</span>
<span class="lineno">954</span>
<span class="lineno">955</span>    <span class="n">tl</span><span class="o">.</span><span class="n">static_assert</span><span class="p">(</span><span class="n">BLOCK_Q</span> <span class="o">%</span> <span class="n">BLOCK_K</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s1">&#39;BLOCK_Q must be divisible by BLOCK_K&#39;</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-115'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-115'>#</a>
            </div>
            <p>Iterate over <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcv" style=""><span class="mord mathnormal" style="margin-right:0.07153em">K</span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">958</span>    <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">steps</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-116'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-116'>#</a>
            </div>
            <p>Load <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.200669em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqbx" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.914561em;"><span style="top:-3.1362300000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.13889em">T</span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">960</span>        <span class="n">b_kT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">p_kT</span><span class="p">,</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">padding_option</span><span class="o">=</span><span class="s2">&quot;zero&quot;</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-117'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-117'>#</a>
            </div>
            <p>Load <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.200669em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqbz" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqck" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.914561em;"><span style="top:-3.1362300000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.13889em">T</span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">962</span>        <span class="n">b_vT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">p_vT</span><span class="p">,</span> <span class="n">boundary_check</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">padding_option</span><span class="o">=</span><span class="s2">&quot;zero&quot;</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-118'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-118'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mopen">(</span><span class="mord coloredeq eqbr" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span><span class="mclose">)</span><span class="mord coloredeq eqbu" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span 
class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.200669em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqbv" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span></span><span class="mopen">(</span><span class="mord coloredeq eqbr" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span><span class="mclose">)</span><span class="mord coloredeq eqcj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord 
mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqbx" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.914561em;"><span style="top:-3.1362300000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.13889em">T</span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">965</span>        <span class="n">b_s</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">b_q</span><span class="p">,</span> <span class="n">b_kT</span><span class="p">,</span> <span class="n">out_dtype</span><span class="o">=</span><span class="n">HI_PRES_TL</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-119'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-119'>#</a>
            </div>
            <span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:6.803331em;vertical-align:-3.1516655em;"></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:3.6516655em;"><span style="top:-5.6983345000000005em;"><span class="pstrut" style="height:3.565em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqbt" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-2.997334499999999em;"><span class="pstrut" style="height:3.565em;"></span><span class="mord" style=""></span></span><span style="top:-1.0733345em;"><span class="pstrut" style="height:3.565em;"></span><span class="mord" style=""></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:3.1516655em;"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:3.6516655em;"><span style="top:-5.6983345000000005em;"><span class="pstrut" style="height:3.565em;"></span><span class="mord" style=""><span class="mord" style=""></span><span class="mspace" style="margin-right:0.2777777777777778em"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em"></span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.518331em;"><span style="top:-2.3139999999999996em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqch" style=""><span class="mord mathnormal" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.841331em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span 
class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbu" style=""><span class="mord mathnormal mtight" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.05764em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.8360000000000001em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span><span style="top:-2.997334499999999em;"><span class="pstrut" style="height:3.565em;"></span><span class="mord" style=""><span class="mord" style=""></span><span class="mspace" style="margin-right:0.2777777777777778em"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em"></span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.565em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style=""><span 
class="mord" style=""><span class="mord coloredeq eqct" style="">2</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7799659999999999em;"><span style="top:-2.9938580000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mop mtight coloredeq eqbp" style=""><span class="mop mtight" style=""><span class="mtight" style="">l</span><span class="mtight" style="">o</span><span class="mtight" style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.19444571428571428em;"><span style="top:-2.2341314285714287em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.26586857142857145em;"><span></span></span></span></span></span></span><span class="mspace mtight" style="margin-right:0.19516666666666668em"></span><span class="mord mtight coloredeq eqbp" style=""><span class="mord mtight coloredeq eqch" style=""><span class="mord mathnormal mtight" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span 
class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqct" style="">2</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8879999999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mopen mtight" style="">(</span><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span><span class="mord mathnormal mtight" style="">o</span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31731428571428577em;"><span style="top:-2.357em;margin-left:-0.03588em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcy" style="">e</span></span><span class="mclose mtight" style="">)</span><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbu" style=""><span 
class="mord mathnormal mtight" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.05764em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.686em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span><span style="top:-1.0733345em;"><span class="pstrut" style="height:3.565em;"></span><span class="mord" style=""><span class="mord" style=""></span><span class="mspace" style="margin-right:0.2777777777777778em"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em"></span><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqct" style="">2</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9379999999999998em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mopen mtight" style="">(</span><span class="mord mathnormal mtight" 
style="margin-right:0.01968em">l</span><span class="mord mathnormal mtight" style="">o</span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31731428571428577em;"><span style="top:-2.357em;margin-left:-0.03588em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcy" style="">e</span></span><span class="mclose mtight" style="">)</span><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbu" style=""><span class="mord mathnormal mtight" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.05764em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span><span class="mbin mtight" style="">−</span><span class="mord mtight" 
style=""><span class="mop mtight coloredeq eqbp" style=""><span class="mop mtight" style=""><span class="mtight" style="">l</span><span class="mtight" style="">o</span><span class="mtight" style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.19444571428571428em;"><span style="top:-2.2341314285714287em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.26586857142857145em;"><span></span></span></span></span></span></span><span class="mspace mtight" style="margin-right:0.19516666666666668em"></span><span class="mord mtight coloredeq eqbp" style=""><span class="mord mtight coloredeq eqch" style=""><span class="mord mathnormal mtight" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:3.1516655em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span><p> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">974</span>        <span class="n">b_p</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">math</span><span class="o">.</span><span class="n">exp2</span><span class="p">(</span><span class="n">b_s</span> <span class="o">-</span> <span class="n">b_lse</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">])</span></pre></div>
        </div>
    </div>
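    <div class='section'>
        <div class='docs'>
            <p>As a sanity check of the identity above, here is a minimal plain-Python sketch; it is not part of the kernel. The helper names <span>p_natural</span> and <span>p_base2</span> are hypothetical. It assumes, as the <span>exp2</span> call above suggests, that the scores are already multiplied by log2(e) and that the log-sum-exp is stored in base 2. </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
import math

LOG2_E = math.log2(math.e)


def p_natural(s, row_sum):
    """P_ij = e^{S_ij} / L_i, computed directly."""
    return math.exp(s) / row_sum


def p_base2(s, row_sum):
    """The same value written as 2^{(log2 e) S_ij - log2 L_i}."""
    return 2.0 ** (LOG2_E * s - math.log2(row_sum))


# One row of scores; L_i is the sum of their exponentials.
scores = [0.5, -1.25, 2.0]
row_sum = sum(math.exp(s) for s in scores)

# Both forms agree up to floating point error, and the row sums to 1.
probs = [p_base2(s, row_sum) for s in scores]
```
</pre></div>
        </div>
    </div>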
    <div class='section' id='section-120'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-120'>#</a>
            </div>
            <p>Autoregressive (causal) masking: zero the probabilities where the key position is after the query position </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">977</span>        <span class="k">if</span> <span class="n">MASK</span><span class="p">:</span>
<span class="lineno">978</span>            <span class="n">causal_mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">offs_i</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">&gt;=</span> <span class="n">offs_j</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:])</span>
<span class="lineno">979</span>            <span class="n">b_p</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">causal_mask</span><span class="p">,</span> <span class="n">b_p</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-121'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-121'>#</a>
            </div>
            <p>Mask out the key positions in the block that are beyond the end of the key/value sequence (<span>kv_seq_len</span>) </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">982</span>        <span class="n">j_mask</span> <span class="o">=</span> <span class="n">offs_j</span> <span class="o">&lt;</span> <span class="n">kv_seq_len</span>
<span class="lineno">983</span>        <span class="n">b_p</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">j_mask</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:],</span> <span class="n">b_p</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">)</span></pre></div>
        </div>
    </div>
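    <div class='section'>
        <div class='docs'>
            <p>The two masks above can be illustrated on a tiny block with plain Python lists. This is only a sketch of the masking logic, not the Triton code; <span>apply_masks</span> is a hypothetical helper and the block values are made up. </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
def apply_masks(p, offs_i, offs_j, kv_seq_len, causal=True):
    """Zero p[i][j] where key j is after query i or past the sequence end."""
    out = []
    for i, row in zip(offs_i, p):
        new_row = []
        for j, value in zip(offs_j, row):
            in_range = kv_seq_len > j           # plays the role of j_mask
            visible = (i >= j) or not causal    # plays the role of causal_mask
            new_row.append(value if (in_range and visible) else 0.0)
        out.append(new_row)
    return out


# Queries at positions 1 and 3, keys at positions 0..3, only 3 valid keys.
block = [[0.1, 0.2, 0.3, 0.4],
         [0.1, 0.2, 0.3, 0.4]]
masked = apply_masks(block, offs_i=[1, 3], offs_j=[0, 1, 2, 3], kv_seq_len=3)
```
</pre></div>
            <p>For query 1 the causal mask removes keys 2 and 3. For query 3 the causal mask keeps every key, and only the sequence-length mask removes key 3. </p>
        </div>
    </div>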
    <div class='section' id='section-122'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-122'>#</a>
            </div>
            <p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:2.463782em;vertical-align:-1.413777em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mop op-limits" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.050005em;"><span style="top:-1.8723309999999997em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span><span style="top:-3.0500049999999996em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op" style="">∑</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.413777em;"><span></span></span></span></span></span><span class="mspace" 
style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqbu" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03148em">k</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.03148em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mop op-limits" style=""><span class="vlist-t vlist-t2"><span 
class="vlist-r"><span class="vlist" style="height:1.050005em;"><span style="top:-1.8723309999999997em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span><span style="top:-3.0500049999999996em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op" style="">∑</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.413777em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqbt" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mord" style=""><span class="delimsizing size1" style=""><span style="">(</span></span></span><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqbt" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span 
class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style=""><span class="mord coloredeq eqcf" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcu" style="margin-right:0.02778em">D</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord" style=""><span class="delimsizing size1" style=""><span style="">)</span></span></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03148em">k</span><span class="msupsub"><span class="vlist-t vlist-t2"><span 
class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.03148em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre></pre></div>
        </div>
    </div>
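    <div class='section'>
        <div class='docs'>
            <p>The formula above can be checked numerically for a single query row. The sketch below is plain Python and is not the kernel. It assumes no softmax scale and no masking, and it takes D_i to be the dot product of dO_i and O_i. All function names and input values are hypothetical. The analytic gradient is compared with central finite differences of the scalar dO_i · O_i(q_i). </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_row(q, ks, vs):
    """Return O_i and P_i for one query row (no scale, no mask)."""
    s = [dot(q, k) for k in ks]
    m = max(s)
    e = [math.exp(x - m) for x in s]
    total = sum(e)
    p = [x / total for x in e]
    o = [sum(p[j] * vs[j][c] for j in range(len(vs))) for c in range(len(vs[0]))]
    return o, p


def dq_row(q, ks, vs, d_out):
    """dq_i = sum_j P_ij (dP_ij - D_i) k_j."""
    o, p = attention_row(q, ks, vs)
    dp = [dot(d_out, v) for v in vs]                       # dP_ij = dO_i V_j^T
    d_i = dot(d_out, o)                                    # D_i (assumed definition)
    ds = [p[j] * (dp[j] - d_i) for j in range(len(ks))]    # dS_ij
    return [sum(ds[j] * ks[j][c] for j in range(len(ks))) for c in range(len(q))]


def dq_numeric(q, ks, vs, d_out, eps=1e-6):
    """Central finite differences of dO_i . O_i(q) with respect to q."""
    grad = []
    for c in range(len(q)):
        hi = list(q)
        lo = list(q)
        hi[c] += eps
        lo[c] -= eps
        f_hi = dot(d_out, attention_row(hi, ks, vs)[0])
        f_lo = dot(d_out, attention_row(lo, ks, vs)[0])
        grad.append((f_hi - f_lo) / (2 * eps))
    return grad


q = [0.3, -0.2]
ks = [[1.0, 0.5], [-0.4, 0.8], [0.2, -1.0]]
vs = [[0.5, 1.0], [-1.0, 0.3], [0.7, -0.6]]
d_out = [0.9, -0.4]

analytic = dq_row(q, ks, vs, d_out)
numeric = dq_numeric(q, ks, vs, d_out)
```
</pre></div>
        </div>
    </div>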
    <div class='section' id='section-123'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-123'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.980548em;vertical-align:-0.286108em;"></span><span class="mord mathnormal">d</span><span class="mord coloredeq eqbt" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.200669em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqci" style=""><span class="mord mathnormal" style="margin-right:0.02778em">O</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" 
style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span class="mord coloredeq eqbz" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eqck" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcx" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.914561em;"><span style="top:-3.1362300000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.13889em">T</span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">988</span>        <span class="n">b_dp</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">b_do</span><span class="p">,</span> <span class="n">b_vT</span><span class="p">,</span> <span class="n">out_dtype</span><span class="o">=</span><span class="n">HI_PRES_TL</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">HI_PRES_TL</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-124'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-124'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.20001em;vertical-align:-0.35001em;"></span><span class="mord coloredeq eqt" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqbu" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord" style=""><span class="mord coloredeq eqbt" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" 
style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mord" style=""><span class="delimsizing size1" style=""><span style="">(</span></span></span><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqbt" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style=""><span class="mord coloredeq eqcf" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcu" style="margin-right:0.02778em">D</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing 
reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord" style=""><span class="delimsizing size1" style=""><span style="">)</span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">990</span>        <span class="n">b_ds</span> <span class="o">=</span> <span class="n">b_p</span> <span class="o">*</span> <span class="p">(</span><span class="n">b_dp</span> <span class="o">-</span> <span class="n">b_pdp</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">])</span></pre></div>
        </div>
    </div>
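    <div class='section'>
        <div class='docs'>
            <p>The step above is the softmax backward pass applied to one row of scores. As a minimal sketch outside the kernel, the pure-Python check below compares that formula with central finite differences, assuming that b_pdp holds D_i = sum_j P_ij dP_ij. The helper names (softmax, softmax_backward, numeric_grad) and the sample values are hypothetical and are not part of this implementation. </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_backward(p, dp):
    """dS_j = P_j * (dP_j - D), where D = sum_j P_j * dP_j."""
    d = sum(pj * dpj for pj, dpj in zip(p, dp))
    return [pj * (dpj - d) for pj, dpj in zip(p, dp)]


def numeric_grad(scores, dp, eps=1e-6):
    """Central finite differences of L(S) = sum_j dP_j * softmax(S)_j."""
    grads = []
    for k in range(len(scores)):
        plus = list(scores)
        minus = list(scores)
        plus[k] += eps
        minus[k] -= eps
        l_plus = sum(a * b for a, b in zip(dp, softmax(plus)))
        l_minus = sum(a * b for a, b in zip(dp, softmax(minus)))
        grads.append((l_plus - l_minus) / (2 * eps))
    return grads


scores = [0.5, -1.0, 2.0, 0.1]  # one row S_i of attention scores (made up)
dp = [0.3, -0.2, 0.7, 1.1]      # upstream gradient dP_i for that row (made up)

p = softmax(scores)
ds = softmax_backward(p, dp)            # analytic dS_i
ds_numeric = numeric_grad(scores, dp)   # numerical dS_i for comparison

print(ds)
print(ds_numeric)
```
</pre></div>
        </div>
    </div>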
    <div class='section' id='section-125'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-125'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord coloredeq eqbr" style=""><span class="mop" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span><span class="mclose">)</span><span class="mord coloredeq eqcd" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord coloredeq eqcj" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.185818em;vertical-align:-0.43581800000000004em;"></span><span class="mop"><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em;">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16195399999999993em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqda" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.43581800000000004em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal">d</span><span class="mord coloredeq eqbu" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcz" style="">i</span></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqbe" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbv" style="margin-right:0.03588em">σ</span></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mop coloredeq eqbr" style=""><span class="mop" style=""><span style="">l</span><span style="">o</span><span style="margin-right:0.01389em">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqct" style="">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal coloredeq eqcy" style="">e</span></span></span><span class="mclose" style="">)</span><span class="mord" style=""><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcv" style="margin-right:0.07153em">K</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqda" style="margin-right:0.05724em">j</span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">992</span>        <span class="n">b_dq</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">b_ds</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">b_kT</span><span class="o">.</span><span class="n">dtype</span><span class="p">),</span> <span class="n">tl</span><span class="o">.</span><span class="n">trans</span><span class="p">(</span><span class="n">b_kT</span><span class="p">),</span> <span class="n">out_dtype</span><span class="o">=</span><span class="n">HI_PRES_TL</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-126'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-126'>#</a>
            </div>
            <p>Advance the key offsets and the transposed K and V block pointers by BLOCK_K, to the next block of keys. </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">995</span>        <span class="n">offs_j</span> <span class="o">+=</span> <span class="n">BLOCK_K</span>
<span class="lineno">996</span>        <span class="n">p_kT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">advance</span><span class="p">(</span><span class="n">p_kT</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">))</span>
<span class="lineno">997</span>        <span class="n">p_vT</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">advance</span><span class="p">(</span><span class="n">p_vT</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-127'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-127'>#</a>
            </div>
            <p>Return accumulated <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqco" style=""><span class="mord mathnormal" style="">d</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcw" style="">Q</span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">1000</span>    <span class="k">return</span> <span class="n">b_dq</span></pre></div>
        </div>
    </div>
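    <div class='section'>
        <div class='docs'>
            <p>The accumulation loop above can be sketched in plain Python for a single query row. This is an illustration under stated assumptions, not the kernel itself: the keys are taken to be pre-scaled by sigma * log2(e) as in the formula for (log2 e) dQ_i, the number of keys is assumed to be a multiple of the block size, and the final multiplication by ln 2 is simply how this sketch removes the log2(e) factor. The function dq_blockwise and all sample values are hypothetical. </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
import math

LOG2_E = math.log2(math.e)
LN_2 = math.log(2.0)


def dq_blockwise(ds_row, k_scaled, block_k):
    """Accumulate sum_j dS_ij * K_j over key blocks of size block_k.

    Assumes len(k_scaled) is a multiple of block_k.
    """
    d_head = len(k_scaled[0])
    acc = [0.0] * d_head
    offs_j = 0  # index of the first key in the current block
    for _ in range(len(k_scaled) // block_k):
        for j in range(offs_j, offs_j + block_k):
            for c in range(d_head):
                acc[c] += ds_row[j] * k_scaled[j][c]
        offs_j += block_k  # move on to the next key block
    return acc


sigma = 0.5                      # softmax scale (hypothetical value)
ds_row = [0.2, -0.1, 0.4, -0.5]  # dS_ij for one query i (made up)
k = [[1.0, 2.0], [0.5, -1.0], [-2.0, 0.0], [3.0, 1.5]]  # rows K_j (made up)

# Keys pre-scaled by sigma * log2(e), following the formula above.
k_scaled = [[sigma * LOG2_E * x for x in row] for row in k]

acc = dq_blockwise(ds_row, k_scaled, block_k=2)  # this is log2(e) * dQ_i
dq = [LN_2 * a for a in acc]                     # remove the log2(e) factor

# Direct reference without blocking: dQ_i = sigma * sum_j dS_ij * K_j
dq_ref = [
    sigma * sum(ds_row[j] * k[j][c] for j in range(len(k)))
    for c in range(len(k[0]))
]

print(dq)
print(dq_ref)
```
</pre></div>
        </div>
    </div>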
    <div class='footer'>
        <a href="https://labml.ai">labml.ai</a>
    </div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
    function handleImages() {
        var images = document.querySelectorAll('p>img')

        for (var i = 0; i < images.length; ++i) {
            handleImage(images[i])
        }
    }

    function handleImage(img) {
        img.parentElement.style.textAlign = 'center'

        var modal = document.createElement('div')
        modal.id = 'modal'

        var modalContent = document.createElement('div')
        modal.appendChild(modalContent)

        var modalImage = document.createElement('img')
        modalContent.appendChild(modalImage)

        var span = document.createElement('span')
        span.classList.add('close')
        span.textContent = 'x'
        modal.appendChild(span)

        img.onclick = function () {
            console.log('clicked')
            document.body.appendChild(modal)
            modalImage.src = img.src
        }

        span.onclick = function () {
            document.body.removeChild(modal)
        }
    }

    handleImages()
</script>
</body>
</html>